[Mlir-commits] [mlir] [mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (PR #66722)

Peiming Liu llvmlistbot at llvm.org
Mon Sep 18 17:08:24 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/66722

The functionality of the two operations largely overlaps, so let's simplify and keep only one of them.
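For illustration, the call-site migration looks like this (adapted from the codegen hunk in the patch below; `rewriter`, `loc`, `count`, and `added` are the values already in scope at that call site):

```c++
// Before: a dedicated sparse_tensor.sort op on one x buffer and no y buffers.
rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
                        SparseTensorSortKind::HybridQuickSort);

// After: the same sort expressed as sort_coo, with the x ordering given by a
// rank-1 identity permutation map and ny == 0.
rewriter.create<SortCooOp>(
    loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
    rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
```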

From 3daa6dc70c102f924f8e8e88d4eb8b605aff6f22 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Sep 2023 21:48:00 +0000
Subject: [PATCH 1/4] [mlir][sparse] unify sparse_tensor.sort/sort_coo
 operations into one.

---
 .../SparseTensor/IR/SparseTensorOps.td        |  53 +---
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  30 +--
 .../Transforms/SparseBufferRewriting.cpp      | 226 +++++++++---------
 .../Transforms/SparseTensorCodegen.cpp        |  12 +-
 .../Transforms/SparseTensorPasses.cpp         |   1 -
 .../Transforms/SparseTensorRewriting.cpp      |  38 ++-
 6 files changed, 137 insertions(+), 223 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 94301dbcd9f7b42..d83d1ba03feb848 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,61 +762,10 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 // Sparse Tensor Sorting Operations.
 //===----------------------------------------------------------------------===//
 
-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
-    Arguments<(ins Index:$n,
-               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
-               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               SparseTensorSortKindAttr:$algorithm)>  {
-  string summary = "Sorts the arrays in xs and ys lexicographically on the "
-                   "integral values found in the xs list";
-  string description = [{
-    Lexicographically sort the first `n` values in `xs` along with the values in
-    `ys`. Conceptually, the values being sorted are tuples produced by
-    `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
-    along with values in `xs`, but values in `ys` don't affect the
-    lexicographical order. The order in which arrays appear in `xs` affects the
-    sorting result. The operator updates `xs` and `ys` in place with the result
-    of the sorting.
-
-    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
-    "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
-    output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
-    Buffers in `xs` needs to have the same integral element type while buffers
-    in `ys` can have different numeric element types. All buffers in `xs` and
-    `ys` should have a dimension not less than `n`. The behavior of the operator
-    is undefined if this condition is not met. The operator requires at least
-    one buffer in `xs` while `ys` can be empty.
-
-    The enum attribute `algorithm` indicates the sorting algorithm used to
-    implement the operator: hybrid_quick_sort, insertion_sort_stable,
-    quick_sort, or heap_sort.
-
-    Note that this operation is "impure" in the sense that its behavior is
-    solely defined by side-effects and not SSA values.
-
-    Example:
-
-    ```mlir
-    sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-
-    ```mlir
-    sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
-      { alg=1 : index}
-      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
-    ```
-  }];
-  let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
-                       "`:` type($xs) (`jointly` type($ys)^)?";
-  let hasVerifier = 1;
-}
-
 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+               AffineMapAttr:$nx, OptionalAttr<IndexAttr>:$ny,
                SparseTensorSortKindAttr:$algorithm)>  {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e71d2a8dd623a6a..3cd0847bdf73765 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1353,34 +1353,6 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
-LogicalResult SortOp::verify() {
-  if (getXs().empty())
-    return emitError("need at least one xs buffer.");
-
-  std::optional<int64_t> n = getConstantIntValue(getN());
-
-  Type xtp = getMemRefType(getXs().front()).getElementType();
-  auto checkTypes = [&](ValueRange operands,
-                        bool checkEleType = true) -> LogicalResult {
-    for (Value opnd : operands) {
-      auto mtp = getMemRefType(opnd);
-      const DynSize sh = mtp.getShape()[0];
-      // We can't check the size of dynamic dimension at compile-time, but all
-      // xs and ys should have a dimension not less than n at runtime.
-      if (n && !ShapedType::isDynamic(sh) && sh < n.value())
-        return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
-                                       ": {0} < {1}",
-                                       sh, n.value()));
-
-      if (checkEleType && xtp != mtp.getElementType())
-        return emitError("mismatch xs element types");
-    }
-    return success();
-  };
-  RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
-  return n ? checkTypes(getYs(), false) : success();
-}
-
 LogicalResult SortCooOp::verify() {
   std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
@@ -1391,7 +1363,7 @@ LogicalResult SortCooOp::verify() {
   uint64_t n = cn.value();
   uint64_t nx = 1;
   if (auto nxAttr = getNxAttr()) {
-    nx = nxAttr.getInt();
+    nx = nxAttr.getAffineMap().getNumResults();
     if (nx < 1)
      emitError(llvm::formatv("Expected nx >= 1, got {0}", nx));
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 029ecb0708941fe..255c78f8b4eb61f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -46,29 +46,29 @@ static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
 
 using FuncGeneratorType = function_ref<void(
-    OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
+    OpBuilder &, ModuleOp, func::FuncOp, AffineMap, uint64_t, bool, uint32_t)>;
 
 /// Constructs a function name with this format to facilitate quick sort:
-///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
-///   <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
+///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
+///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
-                                         StringRef namePrefix, uint64_t nx,
+                                         StringRef namePrefix, AffineMap xPerm,
                                          uint64_t ny, bool isCoo,
                                          ValueRange operands) {
-  nameOstream << namePrefix << nx << "_"
+  nameOstream << namePrefix << xPerm << "_"
               << getMemRefType(operands[xStartIdx]).getElementType();
 
   if (isCoo)
     nameOstream << "_coo_" << ny;
 
-  uint64_t yBufferOffset = isCoo ? 1 : nx;
+  uint64_t yBufferOffset = isCoo ? 1 : xPerm.getNumResults();
   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
     nameOstream << "_" << getMemRefType(v).getElementType();
 }
 
 /// Looks up a function that is appropriate for the given operands being
 /// sorted, and creates such a function if it doesn't exist yet. The
-/// parameters `nx` and `ny` tell the number of x and y values provided
+/// parameters `xPerm` and `ny` tell the x ordering and the number of y values provided
 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
 //
@@ -79,12 +79,12 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
 static FlatSymbolRefAttr
 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                          TypeRange resultTypes, StringRef namePrefix,
-                         uint64_t nx, uint64_t ny, bool isCoo,
+                         AffineMap xPerm, uint64_t ny, bool isCoo,
                          ValueRange operands, FuncGeneratorType createFunc,
                          uint32_t nTrailingP = 0) {
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
-  getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, isCoo,
                                operands.drop_back(nTrailingP));
 
   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +101,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
         loc, nameOstream.str(),
         FunctionType::get(context, operands.getTypes(), resultTypes));
     func.setPrivate();
-    createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
+    createFunc(builder, module, func, xPerm, ny, isCoo, nTrailingP);
   }
 
   return result;
@@ -110,15 +110,17 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
 /// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInXs(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny, bool isCoo,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
   Value iOffset, jOffset;
   if (isCoo) {
-    Value cstep = constantIndex(builder, loc, nx + ny);
+    Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
     iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
     jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
   }
-  for (uint64_t k = 0; k < nx; k++) {
+  for (AffineExpr e : xPerm.getResults()) {
+    unsigned k = e.cast<AffineDimExpr>().getPosition();
     scf::IfOp ifOp;
     Value i, j, buffer;
     if (isCoo) {
@@ -138,21 +140,30 @@ static void forEachIJPairInXs(
 /// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInAllBuffers(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny, bool isCoo,
+    function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
 
-  // Create code for the first (nx + ny) buffers. When isCoo==true, these
+  // Create code for the first (xPerm + ny) buffers. When isCoo==true, these
   // logical buffers are all from the xy buffer of the sort_coo operator.
-  forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+  SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
+                               xPerm.getResults().end());
+  for (unsigned y = 0; y < ny; y++) {
+    exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
+  }
+  AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
+  assert(xyPerm.isPermutation());
+
+  forEachIJPairInXs(builder, loc, args, xyPerm, 0, isCoo, bodyBuilder);
 
-  uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+  uint64_t numHandledBuffers = isCoo ? 1 : xPerm.getNumResults() + ny;
 
   // Create code for the remaining buffers.
   Value i = args[0];
   Value j = args[1];
   for (const auto &arg :
        llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
-    bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+    bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
   }
 }
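For a concrete picture of the extension above, here is a minimal sketch (hypothetical `xPerm` and `ny`, regular MLIR C++ API) of how the appended y dimensions extend the map:

```c++
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// With xPerm = (d0, d1) -> (d1, d0) and ny = 1, the single y dimension is
// appended in identity order, yielding xyPerm = (d0, d1, d2) -> (d1, d0, d2),
// which is still a permutation, as the assert above demands.
AffineMap extendXPermWithYs(MLIRContext *ctx) {
  Builder b(ctx);
  AffineMap xPerm = AffineMap::getPermutationMap({1u, 0u}, ctx);
  unsigned ny = 1;
  SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
                               xPerm.getResults().end());
  for (unsigned y = 0; y < ny; y++)
    exps.push_back(b.getAffineDimExpr(y + xPerm.getNumResults()));
  return AffineMap::get(exps.size(), /*symbolCount=*/0, exps, ctx);
}
```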
 
@@ -168,7 +179,7 @@ static void forEachIJPairInAllBuffers(
 //     ...
 //     swap(yn[i], yn[j]);
 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
-                       uint64_t nx, uint64_t ny, bool isCoo) {
+                       AffineMap xPerm, uint64_t ny, bool isCoo) {
   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -176,20 +187,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
     builder.create<memref::StoreOp>(loc, vi, buffer, j);
   };
 
-  forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
+  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, isCoo, swapOnePair);
 }
 
 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
 /// each pair is created via `compareBuilder`.
 static Value createInlinedCompareImplementation(
-    OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
-    bool isCoo,
+    OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+    uint64_t ny, bool isCoo,
     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
         compareBuilder) {
   Value result;
   auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
     bool isFirstDim = (k == 0);
-    bool isLastDim = (k == nx - 1);
+    bool isLastDim = (k == xPerm.getNumResults() - 1);
     Value val =
         compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
     if (isFirstDim) {
@@ -202,7 +213,7 @@ static Value createInlinedCompareImplementation(
     }
   };
 
-  forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
+  forEachIJPairInXs(builder, loc, args, xPerm, ny, isCoo, bodyBuilder);
 
   builder.setInsertionPointAfterValue(result);
   return result;
@@ -252,13 +263,14 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
 //     else if (x2[i] != x2[j])
 //       and so on ...
 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
-                                    ValueRange args, uint64_t nx, uint64_t ny,
-                                    bool isCoo, uint32_t nTrailingP = 0) {
+                                    ValueRange args, AffineMap xPerm,
+                                    uint64_t ny, bool isCoo,
+                                    uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
-                                            createEqCompare);
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
+                                            isCoo, createEqCompare);
 }
 
 /// Generates code to compare whether x[i] is less than x[j] and returns the
@@ -306,13 +318,14 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
 //   else
 //       and so on ...
 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
-                                   ValueRange args, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   ValueRange args, AffineMap xPerm,
+                                   uint64_t ny, bool isCoo,
+                                   uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
-  return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
-                                            createLessThanCompare);
+  return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
+                                            isCoo, createLessThanCompare);
 }
 
 /// Creates a function to use a binary search to find the insertion point for
@@ -329,8 +342,9 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
 //   return lo;
 //
 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
-                                   func::FuncOp func, uint64_t nx, uint64_t ny,
-                                   bool isCoo, uint32_t nTrailingP = 0) {
+                                   func::FuncOp func, AffineMap xPerm,
+                                   uint64_t ny, bool isCoo,
+                                   uint32_t nTrailingP = 0) {
   // Binary search doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -368,11 +382,11 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 
   // Compare xs[p] < xs[mid].
   SmallVector<Value> compareOperands{p, mid};
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
   Value cond2 =
-      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
   // Update lo and hi for the WhileOp as follows:
   //   if (xs[p] < xs[mid])
   //     hi = mid;
@@ -394,7 +408,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 /// xs[i] == xs[p].
 static std::pair<Value, Value>
 createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
-               ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
+               ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny,
                bool isCoo, int step) {
   Location loc = func.getLoc();
   scf::WhileOp whileOp =
@@ -414,7 +428,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   }
   compareOperands.append(xs.begin(), xs.end());
   Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   Block *after =
@@ -429,7 +443,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   compareOperands[0] = i;
   compareOperands[1] = p;
   Value compareEq =
-      createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);
+      createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny, isCoo);
 
   return std::make_pair(whileOp.getResult(0), compareEq);
 }
@@ -438,7 +452,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
 /// right after the swap instructions.
 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
-                                       uint64_t nx, uint64_t ny, bool isCoo,
+                                       AffineMap xPerm, uint64_t ny, bool isCoo,
                                        SmallVectorImpl<Value> &swapOperands,
                                        SmallVectorImpl<Value> &compareOperands,
                                        Value a, Value b) {
@@ -446,59 +460,59 @@ static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
   compareOperands[0] = b;
   compareOperands[1] = a;
   Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   swapOperands[0] = b;
   swapOperands[1] = a;
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
   return ifOp;
 }
 
 /// Creates code to insert the 3rd element to a list of two sorted elements.
-static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx,
+static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
                             uint64_t ny, bool isCoo,
                             SmallVectorImpl<Value> &swapOperands,
                             SmallVectorImpl<Value> &compareOperands, Value v0,
                             Value v1, Value v2) {
-  scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
+  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
                                          swapOperands, compareOperands, v1, v2);
-  createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
+  createCompareThenSwap(builder, loc, xPerm, ny, isCoo, swapOperands,
                         compareOperands, v0, v1);
   builder.setInsertionPointAfter(ifOp);
 }
 
 /// Creates code to sort 3 elements.
-static void createSort3(OpBuilder &builder, Location loc, uint64_t nx,
+static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
                         uint64_t ny, bool isCoo,
                         SmallVectorImpl<Value> &swapOperands,
                         SmallVectorImpl<Value> &compareOperands, Value v0,
                         Value v1, Value v2) {
   // Sort the first 2 elements.
   scf::IfOp ifOp1 = createCompareThenSwap(
-      builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, v1);
+      builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0, v1);
   builder.setInsertionPointAfter(ifOp1);
 
   // Insert the 3rd element.
-  createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
+  createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands,
                   v0, v1, v2);
 }
 
 /// Creates code to sort 5 elements.
-static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
+static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
                         uint64_t ny, bool isCoo,
                         SmallVectorImpl<Value> &swapOperands,
                         SmallVectorImpl<Value> &compareOperands, Value v0,
                         Value v1, Value v2, Value v3, Value v4) {
   // Sort the first 3 elements.
-  createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0,
+  createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0,
               v1, v2);
 
   auto insert4th = [&]() {
     scf::IfOp ifOp = createCompareThenSwap(
-        builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v2, v3);
-    createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
-                    v0, v1, v2);
+        builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v2, v3);
+    createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands,
+                    compareOperands, v0, v1, v2);
     builder.setInsertionPointAfter(ifOp);
   };
 
@@ -506,7 +520,7 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
   insert4th();
 
   // Insert the 5th element.
-  scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
+  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
                                          swapOperands, compareOperands, v3, v4);
   insert4th();
   builder.setInsertionPointAfter(ifOp);
@@ -517,11 +531,11 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
 /// the number of values in range [lo, hi) is more than a threshold, we also
 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
-                              func::FuncOp func, uint64_t nx, uint64_t ny,
+                              func::FuncOp func, AffineMap xPerm, uint64_t ny,
                               bool isCoo, Value lo, Value hi, Value mi,
                               ValueRange args) {
   SmallVector<Value> compareOperands{mi, lo};
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
   SmallVector<Value> swapOperands{mi, lo};
@@ -537,7 +551,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 
   // When len < 1000, choose pivot from median of 3 values.
   builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
-  createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo,
+  createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
               mi, hi);
 
   // When len >= 1000, choose pivot from median of 5 values.
@@ -549,8 +563,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
   Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
   // Value b is the middle between [mi, hi].
   b = builder.create<arith::ShRUIOp>(loc, b, c1);
-  createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, a,
-              mi, b, hi);
+  createSort5(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
+              a, mi, b, hi);
 
   builder.setInsertionPointAfter(lenIf);
 }
@@ -586,7 +600,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 //   }
 // }
 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
-                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                 bool isCoo, uint32_t nTrailingP = 0) {
   // Quick sort partition doesn't use trailing parameters.
   (void)nTrailingP;
@@ -606,7 +620,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
 
   Value i = lo;
   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
-  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
+  createChoosePivot(builder, module, func, xPerm, ny, isCoo, i, j, p, args);
   Value trueVal = constantI1(builder, loc, true); // The value for while (true)
   SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
   SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
@@ -628,14 +642,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   j = after->getArgument(1);
   p = after->getArgument(2);
 
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
   auto [iresult, iCompareEq] =
       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
-                     i, p, nx, ny, isCoo, 1);
+                     i, p, xPerm, ny, isCoo, 1);
   i = iresult;
   auto [jresult, jCompareEq] =
       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
-                     j, p, nx, ny, isCoo, -1);
+                     j, p, xPerm, ny, isCoo, -1);
   j = jresult;
 
   // If i < j:
@@ -645,7 +659,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   SmallVector<Value> swapOperands{i, j};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
   // If the pivot is moved, update p with the new pivot.
   Value icond =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -737,7 +751,7 @@ static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
 // }
 //
 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
-                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                 bool isCoo, uint32_t nTrailingP) {
   // The value n is passed in as a trailing parameter.
   assert(nTrailingP == 1);
@@ -768,7 +782,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
   Value c1 = constantIndex(builder, loc, 1);
   SmallVector<Value> compareOperands{start, start};
-  uint64_t numXBuffers = isCoo ? 1 : nx;
+  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
 
@@ -794,7 +808,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
     compareOperands[0] = lChildIdx;
     compareOperands[1] = rChildIdx;
     Value cond2 =
-        createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+        createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
     scf::IfOp if2 =
         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
     builder.setInsertionPointToStart(&if2.getThenRegion().front());
@@ -826,7 +840,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   compareOperands[0] = start;
   compareOperands[1] = childIdx;
   Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   // The after-region of the WhileOp.
@@ -836,7 +850,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   childIdx = after->getArgument(2);
   SmallVector<Value> swapOperands{start, childIdx};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
   start = childIdx;
   Value cond2 =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
@@ -869,7 +883,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
 //      shiftdown(lo, lo, l-1)
 // }
 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
-                               func::FuncOp func, uint64_t nx, uint64_t ny,
+                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                bool isCoo, uint32_t nTrailingP) {
   // Heap sort function doesn't have trailing parameters.
   (void)nTrailingP;
@@ -897,7 +911,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
   shiftDownOperands.push_back(n);
   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
-      builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
+      builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, isCoo,
       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                                shiftDownOperands);
@@ -912,7 +926,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
   SmallVector<Value> swapOperands{lo, loplm1};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
   shiftDownOperands[1] = lo;
   shiftDownOperands[shiftDownOperands.size() - 1] =
       builder.create<arith::SubIOp>(loc, l, c1);
@@ -928,7 +942,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
 /// the bigger partition to be processed by the enclosed while-loop.
 static std::pair<Value, Value>
 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
-                ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
+                ValueRange args, AffineMap xPerm, uint64_t ny, bool isCoo,
                 uint32_t nTrailingP) {
   MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
@@ -937,7 +951,7 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
 
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
+      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
       ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
   Value p = builder
                 .create<func::CallOp>(loc, partitionFunc,
@@ -1008,8 +1022,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 //   }
 // }
 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
-                                 func::FuncOp func, uint64_t nx, uint64_t ny,
-                                 bool isCoo, uint32_t nTrailingP) {
+                                 func::FuncOp func, AffineMap xPerm,
+                                 uint64_t ny, bool isCoo, uint32_t nTrailingP) {
   // Stable sort function doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -1034,8 +1048,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   SmallVector<Value> operands{lo, i};
   operands.append(args.begin() + xStartIdx, args.end());
   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
-      ny, isCoo, operands, createBinarySearchFunc);
+      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
+      xPerm, ny, isCoo, operands, createBinarySearchFunc);
   Value p = builder
                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                       operands)
@@ -1045,7 +1059,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   operands[0] = operands[1] = i;
   SmallVector<Value> d;
   forEachIJPairInAllBuffers(
-      builder, loc, operands, nx, ny, isCoo,
+      builder, loc, operands, xPerm, ny, isCoo,
       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
       });
@@ -1061,7 +1075,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   operands[1] = imj;
   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
   forEachIJPairInAllBuffers(
-      builder, loc, operands, nx, ny, isCoo,
+      builder, loc, operands, xPerm, ny, isCoo,
       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
         builder.create<memref::StoreOp>(loc, t, buffer, imj);
@@ -1071,7 +1085,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointAfter(forOpJ);
   operands[0] = operands[1] = p;
   forEachIJPairInAllBuffers(
-      builder, loc, operands, nx, ny, isCoo,
+      builder, loc, operands, xPerm, ny, isCoo,
       [&](uint64_t k, Value p, Value usused, Value buffer) {
         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
       });
@@ -1123,7 +1137,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
 // }
 //
 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
-                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
                                 bool isCoo, uint32_t nTrailingP) {
   assert(nTrailingP == 1 || nTrailingP == 0);
   bool isHybrid = (nTrailingP == 1);
@@ -1173,7 +1187,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     // When len <= limit.
     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
-        builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
+        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, isCoo,
         ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                  ValueRange(args).drop_back(nTrailingP));
@@ -1193,7 +1207,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     // When depth exceeds limit.
     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
-        builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
+        builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, isCoo,
         ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
                                  ValueRange(args).drop_back(nTrailingP));
@@ -1202,8 +1216,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     // When depth doesn't exceed limit.
     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
     args.back() = depthLimit;
-    std::tie(lo, hi) =
-        createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+    std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
+                                       isCoo, nTrailingP);
     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
 
     builder.setInsertionPointAfter(depthIf);
@@ -1215,8 +1229,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     lo = lenIf.getResult(0);
     hi = lenIf.getResult(1);
   } else {
-    std::tie(lo, hi) =
-        createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+    std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
+                                       isCoo, nTrailingP);
   }
 
   // New [lo, hi) for the next while-loop iteration.
@@ -1229,7 +1243,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
 
 /// Implements the rewriting for operator sort and sort_coo.
 template <typename OpTy>
-LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
+LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
                                     uint64_t ny, bool isCoo,
                                     PatternRewriter &rewriter) {
   Location loc = op.getLoc();
@@ -1284,9 +1298,9 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
     break;
   }
 
-  FlatSymbolRefAttr func =
-      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
-                               ny, isCoo, operands, funcGenerator, nTrailingP);
+  FlatSymbolRefAttr func = getMangledSortHelperFunc(
+      rewriter, insertPoint, TypeRange(), funcName, xPerm, ny, isCoo, operands,
+      funcGenerator, nTrailingP);
   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
   return success();
 }
@@ -1410,20 +1424,6 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
   bool enableBufferInitialization;
 };
 
-/// Sparse rewriting rule for the sort operator.
-struct SortRewriter : public OpRewritePattern<SortOp> {
-public:
-  using OpRewritePattern<SortOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(SortOp op,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<Value> xys(op.getXs());
-    xys.append(op.getYs().begin(), op.getYs().end());
-    return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
-                                 /*isCoo=*/false, rewriter);
-  }
-};
-
 /// Sparse rewriting rule for the sort_coo operator.
 struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
 public:
@@ -1434,15 +1434,13 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
     SmallVector<Value> xys;
     xys.push_back(op.getXy());
     xys.append(op.getYs().begin(), op.getYs().end());
-    uint64_t nx = 1;
-    if (auto nxAttr = op.getNxAttr())
-      nx = nxAttr.getInt();
 
+    auto xPerm = op.getNx();
     uint64_t ny = 0;
     if (auto nyAttr = op.getNyAttr())
       ny = nyAttr.getInt();
 
-    return matchAndRewriteSortOp(op, xys, nx, ny,
+    return matchAndRewriteSortOp(op, xys, xPerm, ny,
                                  /*isCoo=*/true, rewriter);
   }
 };
@@ -1457,5 +1455,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                          bool enableBufferInitialization) {
   patterns.add<PushBackRewriter>(patterns.getContext(),
                                  enableBufferInitialization);
-  patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
+  patterns.add<SortCooRewriter>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 557c5c471c4a77c..4419c39c69927e9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -890,8 +890,9 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     // If the innermost level is ordered, we need to sort the coordinates
     // in the "added" array prior to applying the compression.
     if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
-      rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
-                              SparseTensorSortKind::HybridQuickSort);
+      rewriter.create<SortCooOp>(
+          loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
+          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
     // While performing the insertions, we also need to reset the elements
     // of the values/filled-switch by only iterating over the set elements,
     // to ensure that the runtime complexity remains proportional to the
@@ -1486,9 +1487,10 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
       scf::IfOp ifOp =
           rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      rewriter.create<SortCooOp>(
-          loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(lvlRank),
-          rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
+      auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
+      rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
+                                 rewriter.getIndexAttr(0),
+                                 SparseTensorSortKind::HybridQuickSort);
       rewriter.setInsertionPointAfter(ifOp);
     }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index ca7d8a7850b0b19..7d2f0c7f139cda5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -207,7 +207,6 @@ struct SparseTensorCodegenPass
     ConversionTarget target(*ctx);
     // Most ops in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    target.addLegalOp<SortOp>();
     target.addLegalOp<SortCooOp>();
     target.addLegalOp<PushBackOp>();
     // Storage specifier outlives sparse tensor pipeline.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 47f7dad08c8c920..277903dc55b7432 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1206,29 +1206,23 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
       // Retrieve the values-array.
       Value y = genToValues(rewriter, loc, src);
       const auto encSrc = srcTp.getEncoding();
-      // Sort the COO tensor so that its elements are ordered via increasing
-      // coordinates for the storage ordering of the dst tensor.  Use SortCoo
-      // if the COO tensor has the same ordering as the dst tensor.
-      if (dimRank > 1 && srcTp.hasSameDimToLvl(dstTp)) {
-        Value xs = genToCoordinatesBuffer(rewriter, loc, src);
-        rewriter.create<SortCooOp>(
-            loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
-            rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
-      } else {
-        // Gather the coordinates-arrays in the dst tensor storage order.
-        SmallVector<Value> xs(dstLvlRank);
-        const Level srcLvlRank = srcTp.getLvlRank();
-        for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
-          // FIXME: `toOrigDim` is deprecated
-          Dimension dim = toOrigDim(encSrc, srcLvl);
-          // FIXME: `toStoredDim` is deprecated
-          Level dstLvl = toStoredDim(encDst, dim);
-          xs[dstLvl] =
-              genToCoordinates(rewriter, loc, src, srcLvl, /*cooStart=*/0);
-        }
-        rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y},
-                                SparseTensorSortKind::HybridQuickSort);
+      // Builds the dstLvl -> srcLvl permutation map.
+      SmallVector<AffineExpr> es(dstLvlRank);
+      const Level srcLvlRank = srcTp.getLvlRank();
+      for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
+        // FIXME: `toOrigDim` is deprecated
+        Dimension dim = toOrigDim(encSrc, srcLvl);
+        // FIXME: `toStoredDim` is deprecated
+        Level dstLvl = toStoredDim(encDst, dim);
+        es[dstLvl] = rewriter.getAffineDimExpr(srcLvl);
       }
+      auto xPerm = AffineMap::get(dstLvlRank, 0, es, rewriter.getContext());
+      assert(xPerm.isPermutation()); // must be a permutation.
+
+      Value xs = genToCoordinatesBuffer(rewriter, loc, src);
+      rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y}, xPerm,
+                                 rewriter.getIndexAttr(0),
+                                 SparseTensorSortKind::HybridQuickSort);
     }
 
     // For each element in the COO tensor, insert the element to the dst tensor.
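To ground the permutation construction just above, here is a small sketch with made-up level orders (`dimOfSrcLvl` and `dstLvlOfDim` are hypothetical stand-ins for the deprecated `toOrigDim`/`toStoredDim` lookups):

```c++
#include <cassert>

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

AffineMap buildDstToSrcPerm(MLIRContext *ctx) {
  Builder b(ctx);
  unsigned dimOfSrcLvl[3] = {2, 0, 1}; // hypothetical toOrigDim results
  unsigned dstLvlOfDim[3] = {0, 2, 1}; // hypothetical toStoredDim results
  SmallVector<AffineExpr> es(3);
  for (unsigned srcLvl = 0; srcLvl < 3; srcLvl++) {
    unsigned dstLvl = dstLvlOfDim[dimOfSrcLvl[srcLvl]];
    es[dstLvl] = b.getAffineDimExpr(srcLvl); // dst position reads src coord
  }
  // es = {d1, d0, d2}, i.e. xPerm = (d0, d1, d2) -> (d1, d0, d2).
  AffineMap xPerm = AffineMap::get(3, /*symbolCount=*/0, es, ctx);
  assert(xPerm.isPermutation()); // every dst level covered exactly once
  return xPerm;
}
```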

From 6bbb3ba6543139af10d9dc86b6728495c6e77b73 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Sep 2023 22:10:23 +0000
Subject: [PATCH 2/4] remove deadcode after cleanup

---
 .../Transforms/SparseBufferRewriting.cpp      | 235 ++++++++----------
 1 file changed, 103 insertions(+), 132 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 255c78f8b4eb61f..7011578a5afa0b7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -45,23 +45,20 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
 static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
 
-using FuncGeneratorType = function_ref<void(
-    OpBuilder &, ModuleOp, func::FuncOp, AffineMap, uint64_t, bool, uint32_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+                                            AffineMap, uint64_t, uint32_t)>;
 
 /// Constructs a function name with this format to facilitate quick sort:
 ///   <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
 ///   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                          StringRef namePrefix, AffineMap xPerm,
-                                         uint64_t ny, bool isCoo,
-                                         ValueRange operands) {
+                                         uint64_t ny, ValueRange operands) {
   nameOstream << namePrefix << xPerm << "_"
               << getMemRefType(operands[xStartIdx]).getElementType();
+  nameOstream << "_coo_" << ny;
 
-  if (isCoo)
-    nameOstream << "_coo_" << ny;
-
-  uint64_t yBufferOffset = isCoo ? 1 : xPerm.getNumResults();
+  constexpr uint64_t yBufferOffset = 1;
   for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
     nameOstream << "_" << getMemRefType(v).getElementType();
 }
@@ -69,22 +66,19 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
 /// Looks up a function that is appropriate for the given operands being
 /// sorted, and creates such a function if it doesn't exist yet. The
 /// parameters `xPerm` and `ny` tell the x ordering and the number of y values provided
-/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
-/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+/// by the buffer in xStartIdx.
 //
 // All sorting function generators take (lo, hi, xs, ys) in `operands` as
 // parameters for the sorting functions. Other parameters, such as the recursive
 // call depth, are appended to the end of the parameter list as
 // "trailing parameters".
-static FlatSymbolRefAttr
-getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
-                         TypeRange resultTypes, StringRef namePrefix,
-                         AffineMap xPerm, uint64_t ny, bool isCoo,
-                         ValueRange operands, FuncGeneratorType createFunc,
-                         uint32_t nTrailingP = 0) {
+static FlatSymbolRefAttr getMangledSortHelperFunc(
+    OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
+    StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
+    FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
-  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, isCoo,
+  getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
                                operands.drop_back(nTrailingP));
 
   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +95,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
         loc, nameOstream.str(),
         FunctionType::get(context, operands.getTypes(), resultTypes));
     func.setPrivate();
-    createFunc(builder, module, func, xPerm, ny, isCoo, nTrailingP);
+    createFunc(builder, module, func, xPerm, ny, nTrailingP);
   }
 
   return result;
@@ -111,28 +105,20 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInXs(
     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
-    uint64_t ny, bool isCoo,
+    uint64_t ny,
     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-  Value iOffset, jOffset;
-  if (isCoo) {
-    Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
-    iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
-    jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
-  }
+  Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
+  Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+  Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
   for (AffineExpr e : xPerm.getResults()) {
     unsigned k = e.cast<AffineDimExpr>().getPosition();
     scf::IfOp ifOp;
     Value i, j, buffer;
-    if (isCoo) {
-      Value ck = constantIndex(builder, loc, k);
-      i = builder.create<arith::AddIOp>(loc, ck, iOffset);
-      j = builder.create<arith::AddIOp>(loc, ck, jOffset);
-      buffer = args[xStartIdx];
-    } else {
-      i = args[0];
-      j = args[1];
-      buffer = args[xStartIdx + k];
-    }
+    Value ck = constantIndex(builder, loc, k);
+    i = builder.create<arith::AddIOp>(loc, ck, iOffset);
+    j = builder.create<arith::AddIOp>(loc, ck, jOffset);
+    buffer = args[xStartIdx];
+
     bodyBuilder(k, i, j, buffer);
   }
 }
@@ -141,11 +127,10 @@ static void forEachIJPairInXs(
 /// The code to process the value pairs is generated by `bodyBuilder`.
 static void forEachIJPairInAllBuffers(
     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
-    uint64_t ny, bool isCoo,
+    uint64_t ny,
     function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
 
-  // Create code for the first (xPerm + ny) buffers. When isCoo==true, these
-  // logical buffers are all from the xy buffer of the sort_coo operator.
+  // Create code for the first (xPerm + ny) buffers.
   SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
                                xPerm.getResults().end());
   for (unsigned y = 0; y < ny; y++) {
@@ -154,10 +139,9 @@ static void forEachIJPairInAllBuffers(
   AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
   assert(xyPerm.isPermutation());
 
-  forEachIJPairInXs(builder, loc, args, xyPerm, 0, isCoo, bodyBuilder);
-
-  uint64_t numHandledBuffers = isCoo ? 1 : xPerm.getNumResults() + ny;
+  forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
 
+  constexpr uint64_t numHandledBuffers = 1;
   // Create code for the remaining buffers.
   Value i = args[0];
   Value j = args[1];
@@ -179,7 +163,7 @@ static void forEachIJPairInAllBuffers(
 //     ...
 //     swap(yn[i], yn[j]);
 static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
-                       AffineMap xPerm, uint64_t ny, bool isCoo) {
+                       AffineMap xPerm, uint64_t ny) {
   auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
     Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
     Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -187,14 +171,14 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
     builder.create<memref::StoreOp>(loc, vi, buffer, j);
   };
 
-  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, isCoo, swapOnePair);
+  forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
 }
 
 /// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
 /// each pair is created via `compareBuilder`.
 static Value createInlinedCompareImplementation(
     OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
-    uint64_t ny, bool isCoo,
+    uint64_t ny,
     function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
         compareBuilder) {
   Value result;
@@ -213,7 +197,7 @@ static Value createInlinedCompareImplementation(
     }
   };
 
-  forEachIJPairInXs(builder, loc, args, xPerm, ny, isCoo, bodyBuilder);
+  forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
 
   builder.setInsertionPointAfterValue(result);
   return result;
@@ -264,13 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
 //       and so on ...
 static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
                                     ValueRange args, AffineMap xPerm,
-                                    uint64_t ny, bool isCoo,
-                                    uint32_t nTrailingP = 0) {
+                                    uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
-                                            isCoo, createEqCompare);
+                                            createEqCompare);
 }
 
 /// Generates code to compare whether x[i] is less than x[j] and returns the
@@ -319,13 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
 //       and so on ...
 static Value createInlinedLessThan(OpBuilder &builder, Location loc,
                                    ValueRange args, AffineMap xPerm,
-                                   uint64_t ny, bool isCoo,
-                                   uint32_t nTrailingP = 0) {
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Compare functions don't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
   return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
-                                            isCoo, createLessThanCompare);
+                                            createLessThanCompare);
 }
 
 /// Creates a function to use a binary search to find the insertion point for
@@ -343,8 +325,7 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
 //
 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                    func::FuncOp func, AffineMap xPerm,
-                                   uint64_t ny, bool isCoo,
-                                   uint32_t nTrailingP = 0) {
+                                   uint64_t ny, uint32_t nTrailingP = 0) {
   // Binary search doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -382,11 +363,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 
   // Compare xs[p] < xs[mid].
   SmallVector<Value> compareOperands{p, mid};
-  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
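+  // All xs now reside in the single xy buffer.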
+  constexpr uint64_t numXBuffers = 1;
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
-  Value cond2 =
-      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+  Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
   // Update lo and hi for the WhileOp as follows:
   //   if (xs[p] < xs[mid]))
   //     hi = mid;
@@ -406,10 +386,11 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
 ///   while (xs[i] > xs[p]) i += step (step < 0)
 /// The routine returns i as well as a boolean value to indicate whether
 /// xs[i] == xs[p].
-static std::pair<Value, Value>
-createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
-               ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny,
-               bool isCoo, int step) {
+static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
+                                              ModuleOp module,
+                                              func::FuncOp func, ValueRange xs,
+                                              Value i, Value p, AffineMap xPerm,
+                                              uint64_t ny, int step) {
   Location loc = func.getLoc();
   scf::WhileOp whileOp =
       builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
@@ -427,8 +408,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
     compareOperands.push_back(before->getArgument(0));
   }
   compareOperands.append(xs.begin(), xs.end());
-  Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   Block *after =
@@ -443,7 +423,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   compareOperands[0] = i;
   compareOperands[1] = p;
   Value compareEq =
-      createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny, isCoo);
+      createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
 
   return std::make_pair(whileOp.getResult(0), compareEq);
 }
@@ -452,67 +432,63 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 /// if compareFunc(data[b], data[a]) returns true. The new insertion point is
 /// right after the swap instructions.
 static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
-                                       AffineMap xPerm, uint64_t ny, bool isCoo,
+                                       AffineMap xPerm, uint64_t ny,
                                        SmallVectorImpl<Value> &swapOperands,
                                        SmallVectorImpl<Value> &compareOperands,
                                        Value a, Value b) {
   // Compare(data[b], data[a]).
   compareOperands[0] = b;
   compareOperands[1] = a;
-  Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   swapOperands[0] = b;
   swapOperands[1] = a;
-  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   return ifOp;
 }
 
 /// Creates code to insert the 3rd element into a list of two sorted elements.
 static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
-                            uint64_t ny, bool isCoo,
-                            SmallVectorImpl<Value> &swapOperands,
+                            uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                             SmallVectorImpl<Value> &compareOperands, Value v0,
                             Value v1, Value v2) {
-  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
-                                         swapOperands, compareOperands, v1, v2);
-  createCompareThenSwap(builder, loc, xPerm, ny, isCoo, swapOperands,
-                        compareOperands, v0, v1);
+  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+                                         compareOperands, v1, v2);
+  createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
+                        v0, v1);
   builder.setInsertionPointAfter(ifOp);
 }
 
 /// Creates code to sort 3 elements.
 static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
-                        uint64_t ny, bool isCoo,
-                        SmallVectorImpl<Value> &swapOperands,
+                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                         SmallVectorImpl<Value> &compareOperands, Value v0,
                         Value v1, Value v2) {
   // Sort the first 2 elements.
-  scf::IfOp ifOp1 = createCompareThenSwap(
-      builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0, v1);
+  scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+                                          compareOperands, v0, v1);
   builder.setInsertionPointAfter(ifOp1);
 
   // Insert the 3rd element.
-  createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands,
-                  v0, v1, v2);
+  createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
+                  v1, v2);
 }
 
 /// Creates code to sort 5 elements.
 static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
-                        uint64_t ny, bool isCoo,
-                        SmallVectorImpl<Value> &swapOperands,
+                        uint64_t ny, SmallVectorImpl<Value> &swapOperands,
                         SmallVectorImpl<Value> &compareOperands, Value v0,
                         Value v1, Value v2, Value v3, Value v4) {
   // Sort the first 3 elements.
-  createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0,
-              v1, v2);
+  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
+              v2);
 
   auto insert4th = [&]() {
     scf::IfOp ifOp = createCompareThenSwap(
-        builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v2, v3);
-    createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands,
-                    compareOperands, v0, v1, v2);
+        builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
+    createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
+                    v1, v2);
     builder.setInsertionPointAfter(ifOp);
   };
 
@@ -520,8 +496,8 @@ static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
   insert4th();
 
   // Insert the 5th element.
-  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
-                                         swapOperands, compareOperands, v3, v4);
+  scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+                                         compareOperands, v3, v4);
   insert4th();
   builder.setInsertionPointAfter(ifOp);
 }
@@ -532,10 +508,9 @@ static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
 /// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
 static void createChoosePivot(OpBuilder &builder, ModuleOp module,
                               func::FuncOp func, AffineMap xPerm, uint64_t ny,
-                              bool isCoo, Value lo, Value hi, Value mi,
-                              ValueRange args) {
+                              Value lo, Value hi, Value mi, ValueRange args) {
   SmallVector<Value> compareOperands{mi, lo};
-  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
+  constexpr uint64_t numXBuffers = 1;
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
   SmallVector<Value> swapOperands{mi, lo};
@@ -551,8 +526,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 
   // When len < 1000, choose pivot from median of 3 values.
   builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
-  createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
-              mi, hi);
+  createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
+              hi);
 
   // When len >= 1000, choose pivot from median of 5 values.
   builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
@@ -563,8 +538,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
   Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
   // Value b is the middle between [mi, hi].
   b = builder.create<arith::ShRUIOp>(loc, b, c1);
-  createSort5(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
-              a, mi, b, hi);
+  createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
+              b, hi);
 
   builder.setInsertionPointAfter(lenIf);
 }
@@ -601,7 +576,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 // }
 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
-                                bool isCoo, uint32_t nTrailingP = 0) {
+                                uint32_t nTrailingP = 0) {
   // Quick sort partition doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -620,7 +595,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
 
   Value i = lo;
   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
-  createChoosePivot(builder, module, func, xPerm, ny, isCoo, i, j, p, args);
+  createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
   Value trueVal = constantI1(builder, loc, true); // The value for while (true)
   SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
   SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
@@ -642,14 +617,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   j = after->getArgument(1);
   p = after->getArgument(2);
 
-  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
+  constexpr uint64_t numXBuffers = 1;
   auto [iresult, iCompareEq] =
       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
-                     i, p, xPerm, ny, isCoo, 1);
+                     i, p, xPerm, ny, 1);
   i = iresult;
   auto [jresult, jCompareEq] =
       createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
-                     j, p, xPerm, ny, isCoo, -1);
+                     j, p, xPerm, ny, -1);
   j = jresult;
 
   // If i < j:
@@ -659,7 +634,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   SmallVector<Value> swapOperands{i, j};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   // If the pivot is moved, update p with the new pivot.
   Value icond =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -752,7 +727,7 @@ static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
 //
 static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
-                                bool isCoo, uint32_t nTrailingP) {
+                                uint32_t nTrailingP) {
   // The value n is passed in as a trailing parameter.
   assert(nTrailingP == 1);
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -782,7 +757,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
   Value c1 = constantIndex(builder, loc, 1);
   SmallVector<Value> compareOperands{start, start};
-  uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
+  constexpr uint64_t numXBuffers = 1;
   compareOperands.append(args.begin() + xStartIdx,
                          args.begin() + xStartIdx + numXBuffers);
 
@@ -808,7 +783,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
     compareOperands[0] = lChildIdx;
     compareOperands[1] = rChildIdx;
     Value cond2 =
-        createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+        createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
     scf::IfOp if2 =
         builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
     builder.setInsertionPointToStart(&if2.getThenRegion().front());
@@ -839,8 +814,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   childIdx = before->getArgument(2);
   compareOperands[0] = start;
   compareOperands[1] = childIdx;
-  Value cond =
-      createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+  Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
   builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
   // The after-region of the WhileOp.
@@ -850,7 +824,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
   childIdx = after->getArgument(2);
   SmallVector<Value> swapOperands{start, childIdx};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   start = childIdx;
   Value cond2 =
       builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
@@ -884,7 +858,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
 // }
 static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
                                func::FuncOp func, AffineMap xPerm, uint64_t ny,
-                               bool isCoo, uint32_t nTrailingP) {
+                               uint32_t nTrailingP) {
   // Heap sort function doesn't have trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -911,7 +885,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   shiftDownOperands.append(args.begin() + xStartIdx, args.end());
   shiftDownOperands.push_back(n);
   FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
-      builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, isCoo,
+      builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
       shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
   builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
                                shiftDownOperands);
@@ -926,7 +900,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
   SmallVector<Value> swapOperands{lo, loplm1};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+  createSwap(builder, loc, swapOperands, xPerm, ny);
   shiftDownOperands[1] = lo;
   shiftDownOperands[shiftDownOperands.size() - 1] =
       builder.create<arith::SubIOp>(loc, l, c1);
@@ -942,7 +916,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
 /// the bigger partition to be processed by the enclosing while-loop.
 static std::pair<Value, Value>
 createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
-                ValueRange args, AffineMap xPerm, uint64_t ny, bool isCoo,
+                ValueRange args, AffineMap xPerm, uint64_t ny,
                 uint32_t nTrailingP) {
   MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
@@ -952,7 +926,7 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
-      ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
+      ny, args.drop_back(nTrailingP), createPartitionFunc);
   Value p = builder
                 .create<func::CallOp>(loc, partitionFunc,
                                       TypeRange{IndexType::get(context)},
@@ -1023,7 +997,7 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
 // }
 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                  func::FuncOp func, AffineMap xPerm,
-                                 uint64_t ny, bool isCoo, uint32_t nTrailingP) {
+                                 uint64_t ny, uint32_t nTrailingP) {
   // Stable sort function doesn't use trailing parameters.
   (void)nTrailingP;
   assert(nTrailingP == 0);
@@ -1049,7 +1023,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   operands.append(args.begin() + xStartIdx, args.end());
   FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
       builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
-      xPerm, ny, isCoo, operands, createBinarySearchFunc);
+      xPerm, ny, operands, createBinarySearchFunc);
   Value p = builder
                 .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
                                       operands)
@@ -1059,7 +1033,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   operands[0] = operands[1] = i;
   SmallVector<Value> d;
   forEachIJPairInAllBuffers(
-      builder, loc, operands, xPerm, ny, isCoo,
+      builder, loc, operands, xPerm, ny,
       [&](uint64_t unused, Value i, Value unused2, Value buffer) {
         d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
       });
@@ -1075,7 +1049,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   operands[1] = imj;
   operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
   forEachIJPairInAllBuffers(
-      builder, loc, operands, xPerm, ny, isCoo,
+      builder, loc, operands, xPerm, ny,
       [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
         Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
         builder.create<memref::StoreOp>(loc, t, buffer, imj);
@@ -1085,7 +1059,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   builder.setInsertionPointAfter(forOpJ);
   operands[0] = operands[1] = p;
   forEachIJPairInAllBuffers(
-      builder, loc, operands, xPerm, ny, isCoo,
+      builder, loc, operands, xPerm, ny,
       [&](uint64_t k, Value p, Value unused, Value buffer) {
         builder.create<memref::StoreOp>(loc, d[k], buffer, p);
       });
@@ -1138,7 +1112,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
 //
 static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, AffineMap xPerm, uint64_t ny,
-                                bool isCoo, uint32_t nTrailingP) {
+                                uint32_t nTrailingP) {
   assert(nTrailingP == 1 || nTrailingP == 0);
   bool isHybrid = (nTrailingP == 1);
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -1187,7 +1161,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     // When len <= limit.
     builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
     FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
-        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, isCoo,
+        builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
         ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
     builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
                                  ValueRange(args).drop_back(nTrailingP));
@@ -1207,7 +1181,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     // When depth exceeds limit.
     builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
     FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
-        builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, isCoo,
+        builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
         ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
     builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
                                  ValueRange(args).drop_back(nTrailingP));
@@ -1216,8 +1190,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     // When depth doesn't exceed limit.
     builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
     args.back() = depthLimit;
-    std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
-                                       isCoo, nTrailingP);
+    std::tie(lo, hi) =
+        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
     builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
 
     builder.setInsertionPointAfter(depthIf);
@@ -1229,8 +1203,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
     lo = lenIf.getResult(0);
     hi = lenIf.getResult(1);
   } else {
-    std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
-                                       isCoo, nTrailingP);
+    std::tie(lo, hi) =
+        createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
   }
 
   // New [lo, hi) for the next while-loop iteration.
@@ -1244,8 +1218,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
 /// Implements the rewriting for the sort_coo operator.
 template <typename OpTy>
 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
-                                    uint64_t ny, bool isCoo,
-                                    PatternRewriter &rewriter) {
+                                    uint64_t ny, PatternRewriter &rewriter) {
   Location loc = op.getLoc();
   SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
 
@@ -1298,9 +1271,9 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
     break;
   }
 
-  FlatSymbolRefAttr func = getMangledSortHelperFunc(
-      rewriter, insertPoint, TypeRange(), funcName, xPerm, ny, isCoo, operands,
-      funcGenerator, nTrailingP);
+  FlatSymbolRefAttr func =
+      getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
+                               xPerm, ny, operands, funcGenerator, nTrailingP);
   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
   return success();
 }
@@ -1310,7 +1283,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
 //===---------------------------------------------------------------------===//
 
 namespace {
-
 /// Sparse rewriting rule for the push_back operator.
 struct PushBackRewriter : OpRewritePattern<PushBackOp> {
 public:
@@ -1440,8 +1412,7 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
     if (auto nyAttr = op.getNyAttr())
       ny = nyAttr.getInt();
 
-    return matchAndRewriteSortOp(op, xys, xPerm, ny,
-                                 /*isCoo=*/true, rewriter);
+    return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
   }
 };
 

>From b21bc7ed97444ff95f8607895b14bb1198db3ad3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Sep 2023 23:13:46 +0000
Subject: [PATCH 3/4] update test cases.

---
 .../Transforms/SparseBufferRewriting.cpp      |  14 +-
 .../SparseTensor/CPU/sparse_rewrite_sort.mlir | 187 ------------------
 .../CPU/sparse_rewrite_sort_coo.mlir          |   8 +-
 3 files changed, 11 insertions(+), 198 deletions(-)
 delete mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 7011578a5afa0b7..101bd165cc598b2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -110,14 +110,12 @@ static void forEachIJPairInXs(
   Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
   Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
   Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
-  for (AffineExpr e : xPerm.getResults()) {
-    unsigned k = e.cast<AffineDimExpr>().getPosition();
-    scf::IfOp ifOp;
-    Value i, j, buffer;
-    Value ck = constantIndex(builder, loc, k);
-    i = builder.create<arith::AddIOp>(loc, ck, iOffset);
-    j = builder.create<arith::AddIOp>(loc, ck, jOffset);
-    buffer = args[xStartIdx];
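+  // For each logical level k, the permutation yields the physical coordinate
+  // position inside a tuple of width rank(perm_map) + ny.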
+  for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
+    unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
+    Value ak = constantIndex(builder, loc, actualK);
+    Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
+    Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
+    Value buffer = args[xStartIdx];
 
     bodyBuilder(k, i, j, buffer);
   }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
deleted file mode 100644
index 9e8ecad9cf282a2..000000000000000
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ /dev/null
@@ -1,187 +0,0 @@
-//--------------------------------------------------------------------------------------------------
-// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
-//
-// Set-up that's shared across all tests in this directory. In principle, this
-// config could be moved to lit.local.cfg. However, there are downstream users that
-//  do not use these LIT config files. Hence why this is kept inline.
-//
-// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
-// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
-// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
-// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
-// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
-// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
-// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
-//
-// DEFINE: %{env} =
-//--------------------------------------------------------------------------------------------------
-
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with vectorization.
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with  VLA vectorization.
-// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-
-module {
-  func.func private @printMemref1dI32(%ptr : memref<?xi32>) attributes { llvm.emit_c_interface }
-
-  // Stores 5 values to the memref buffer.
-  func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
-    %v3: i32, %v4: i32) -> () {
-    %i0 = arith.constant 0 : index
-    %i1 = arith.constant 1 : index
-    %i2 = arith.constant 2 : index
-    %i3 = arith.constant 3 : index
-    %i4 = arith.constant 4 : index
-    memref.store %v0, %b[%i0] : memref<?xi32>
-    memref.store %v1, %b[%i1] : memref<?xi32>
-    memref.store %v2, %b[%i2] : memref<?xi32>
-    memref.store %v3, %b[%i3] : memref<?xi32>
-    memref.store %v4, %b[%i4] : memref<?xi32>
-    return
-  }
-
-  // The main driver.
-  func.func @entry() {
-    %c0 = arith.constant 0 : i32
-    %c1 = arith.constant 1 : i32
-    %c2 = arith.constant 2 : i32
-    %c3 = arith.constant 3 : i32
-    %c4 = arith.constant 4 : i32
-    %c5 = arith.constant 5 : i32
-    %c6 = arith.constant 6 : i32
-    %c7 = arith.constant 7 : i32
-    %c8 = arith.constant 8 : i32
-    %c9 = arith.constant 9 : i32
-    %c10 = arith.constant 10 : i32
-    %c100 = arith.constant 100 : i32
-
-    %i0 = arith.constant 0 : index
-    %i4 = arith.constant 4 : index
-    %i5 = arith.constant 5 : index
-
-    // Prepare a buffer.
-    %x0s = memref.alloc() : memref<5xi32>
-    %x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
-    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-
-    // Sort 0 elements.
-    // Quick sort.
-    // CHECK: [10,  2,  0,  5,  1]
-    sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Stable sort.
-    // CHECK: [10,  2,  0,  5,  1]
-    sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Heap sort.
-    // CHECK: [10,  2,  0,  5,  1]
-    sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Hybrid sort.
-    // CHECK: [10,  2,  0,  5,  1]
-    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-
-    // Sort the first 4 elements, with the last valid value untouched.
-    // Quick sort.
-    // CHECK: [0,  2,  5, 10,  1]
-    sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Stable sort.
-    // CHECK: [0,  2,  5,  10,  1]
-    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Heap sort.
-    // CHECK: [0,  2,  5,  10,  1]
-    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    // Hybrid sort.
-    // CHECK: [0,  2,  5, 10,  1]
-    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-
-    // Prepare more buffers of different dimensions.
-    %x1s = memref.alloc() : memref<10xi32>
-    %x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
-    %x2s = memref.alloc() : memref<6xi32>
-    %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
-    %y0s = memref.alloc() : memref<7xi32>
-    %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
-
-    // Sort "parallel arrays".
-    // CHECK: [1,  1,  2,  5,  10]
-    // CHECK: [3,  3,  1,  10,  1
-    // CHECK: [9,  9,  4,  7,  2
-    // CHECK: [7,  8,  10,  9,  6
-    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
-      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
-    // Stable sort.
-    // CHECK: [1,  1,  2,  5,  10]
-    // CHECK: [3,  3,  1,  10,  1
-    // CHECK: [9,  9,  4,  7,  2
-    // CHECK: [8,  7,  10,  9,  6
-    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
-      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
-    // Heap sort.
-    // CHECK: [1,  1,  2,  5,  10]
-    // CHECK: [3,  3,  1,  10,  1
-    // CHECK: [9,  9,  4,  7,  2
-    // CHECK: [7,  8,  10,  9,  6
-    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
-      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
-      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
-    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
-    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
-
-    // Release the buffers.
-    memref.dealloc %x0 : memref<?xi32>
-    memref.dealloc %x1 : memref<?xi32>
-    memref.dealloc %x2 : memref<?xi32>
-    memref.dealloc %y0 : memref<?xi32>
-    return
-  }
-}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index ca5dd00d02aff1e..0594b311184f4d5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -28,6 +28,8 @@
 // Do the same run, but now with VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
 
+#ID_MAP = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
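+// The identity permutation compares coordinates in their original order.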
+
 module {
   // Stores 5 values to the memref buffer.
   func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
@@ -109,7 +111,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x0v : vector<5xi32>
@@ -137,7 +139,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x0v2 : vector<5xi32>
@@ -165,7 +167,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x0v3 : vector<5xi32>

>From bb146d5ead3f29e832b0bf3d3c94046089727052 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 19 Sep 2023 00:06:40 +0000
Subject: [PATCH 4/4] fix check tests

---
 .../SparseTensor/IR/SparseTensorOps.td        |  24 ++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  17 +--
 .../Transforms/SparseBufferRewriting.cpp      |   9 +-
 .../SparseTensor/buffer_rewriting.mlir        | 105 ++++--------------
 mlir/test/Dialect/SparseTensor/codegen.mlir   |   6 +-
 .../SparseTensor/convert_sparse2sparse.mlir   |   2 +-
 mlir/test/Dialect/SparseTensor/invalid.mlir   |  47 ++++----
 mlir/test/Dialect/SparseTensor/roundtrip.mlir |  64 ++---------
 .../SparseTensor/sparse_matmul_codegen.mlir   |   2 +-
 .../CPU/sparse_rewrite_sort_coo.mlir          |  51 ++++-----
 10 files changed, 109 insertions(+), 218 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index d83d1ba03feb848..59815fc755ee5f3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -765,27 +765,29 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
-               AffineMapAttr:$nx, OptionalAttr<IndexAttr>:$ny,
+               AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
                SparseTensorSortKindAttr:$algorithm)>  {
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
-    Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
-    `xs` values and some `ys` values are put in the linear buffer `xy`. The
-    optional index attribute `nx` provides the number of `xs` values in `xy`.
-    When `nx` is not explicitly specified, its value is 1. The optional index
-    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
-    explicitly specified, its value is 0. This instruction supports a more
-    efficient way to store the COO definition in sparse tensor type.
-
-    The buffer xy should have a dimension not less than n * (nx + ny) while the
+    `sparse_tensor.sort_coo` sorts the `xs` values along with some `ys` values
+    that are put in a single linear buffer `xy`.
+    The affine map attribute `perm_map` specifies the permutation to be
+    applied on the `xs` before comparison; the rank of the permutation map
+    also determines the number of `xs` values in `xy`.
+    The optional index attribute `ny` provides the number of `ys` values in
+    `xy`. When `ny` is not explicitly specified, its value is 0.
+    This operation supports a more efficient way to store the COO definition
+    in a sparse tensor type.
+
+    The buffer `xy` should have a dimension not less than
+    `n * (rank(perm_map) + ny)` while the
     buffers in `ys` should have a dimension not less than `n`. The behavior of
     the operator is undefined if this condition is not met.
 
     Example:
 
     ```mlir
-    sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
       : memref<?xindex>
     ```
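+
+    As a further sketch (the operand names here are illustrative), sorting two
+    interleaved coordinates plus one `y` value per tuple, jointly with a
+    separate values buffer, could look like:
+
+    ```mlir
+    sparse_tensor.sort_coo quick_sort %n, %xy jointly %y0
+      { perm_map = affine_map<(i,j) -> (i,j)>, ny = 1 : index }
+      : memref<?xindex> jointly memref<?xf32>
+    ```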
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3cd0847bdf73765..9675a61109477b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1354,6 +1354,14 @@ LogicalResult SelectOp::verify() {
 }
 
 LogicalResult SortCooOp::verify() {
+  AffineMap xPerm = getPermMap();
+  uint64_t nx = xPerm.getNumDims();
+  if (nx < 1)
+    return emitError(
+        llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
+
+  if (!xPerm.isPermutation())
+    return emitError(
+        llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
   std::optional<int64_t> cn = getConstantIntValue(getN());
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -1361,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
     return success();
 
   uint64_t n = cn.value();
-  uint64_t nx = 1;
-  if (auto nxAttr = getNxAttr()) {
-    nx = nxAttr.getAffineMap().getNumResults();
-    if (nx < 1)
-      emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
-  }
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr()) {
     ny = nyAttr.getInt();
@@ -1381,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };
 
-  checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+  checkDim(getXy(), n * (nx + ny),
+           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
 
   for (Value opnd : getYs()) {
     checkDim(opnd, n, "Expected dimension(y) >= n");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 101bd165cc598b2..3181395a474cfec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -54,8 +54,11 @@ using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
 static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
                                          StringRef namePrefix, AffineMap xPerm,
                                          uint64_t ny, ValueRange operands) {
-  nameOstream << namePrefix << xPerm << "_"
-              << getMemRefType(operands[xStartIdx]).getElementType();
+  nameOstream << namePrefix;
+  for (auto res : xPerm.getResults())
+    nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
+
+  nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
   nameOstream << "_coo_" << ny;
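+  // E.g., prefix "_sparse_partition_" with perm_map (d0, d1) -> (d0, d1) over
+  // an index buffer, ny = 1, and f32/i32 y-buffers yields
+  // "_sparse_partition_0_1_index_coo_1_f32_i32".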
 
   constexpr uint64_t yBufferOffset = 1;
@@ -1405,7 +1408,7 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
     xys.push_back(op.getXy());
     xys.append(op.getYs().begin(), op.getYs().end());
 
-    auto xPerm = op.getNx();
+    auto xPerm = op.getPermMap();
     uint64_t ny = 0;
     if (auto nyAttr = op.getNyAttr())
       ny = nyAttr.getInt();
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 0036bd5c3310b97..c96a55aa1e8b2f3 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -75,123 +75,64 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 
 // -----
 
-// CHECK-LABEL:   func.func private @_sparse_partition_1_i8_f32_index
-// CHECK-LABEL:   func.func private @_sparse_qsort_1_i8_f32_index
-// CHECK-LABEL:   func.func @sparse_sort_1d2v_quick
-func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
-   -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
-  sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-//
-// CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG:     func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL:   func.func @sparse_sort_3d_quick
-func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-//
-// CHECK-DAG:     func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG:     func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG:     func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG:     func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
-// CHECK-LABEL:   func.func @sparse_sort_3d_hybrid
-func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG:     func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG:     func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL:   func.func @sparse_sort_3d_stable
-func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
+#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>
 
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG:     func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL:   func.func @sparse_sort_3d_heap
-func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
-  sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG:     func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG:     func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_quick
 func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
 // -----
 
+#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG:     func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG:     func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG:     func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
+// CHECK-DAG:     func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG:     func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG:     func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG:     func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_hybrid
 func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
 // -----
 
+#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG:     func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG:     func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_stable
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
 // -----
 
+#ID_MAP = affine_map<(d0, d1) -> (d0, d1)>
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
-// CHECK-DAG:     func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG:     func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG:     func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG:     func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_heap
 func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
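
Note on the CHECK-DAG updates above: the mangled helper names now encode the
permutation's result order (e.g. `_0_1_`) instead of the old `nx` operand
count (`_2_`). For reference, a minimal sketch of the unified form these
tests exercise (buffer names hypothetical):

  #ID_MAP = affine_map<(d0, d1) -> (d0, d1)>
  // Sorts %n tuples laid out in %xy: two index columns keyed through
  // perm_map plus ny = 1 trailing column, with %extra carried along.
  sparse_tensor.sort_coo hybrid_quick_sort %n, %xy jointly %extra
      {perm_map = #ID_MAP, ny = 1 : index} : memref<100xindex> jointly memref<?xf32>
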
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index f1317f23d656848..ea11a98b76ec639 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -436,7 +436,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //   CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 //   CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 //       CHECK:   %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
 //       CHECK:   %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -484,7 +484,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 //       CHECK:     %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 //       CHECK:     %[[A12:.*]] = arith.constant 1 : index
 //       CHECK:     %[[A13:.*]] = arith.constant 0 : index
-//       CHECK:     sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+//       CHECK:     sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK:     %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
 //       CHECK:       %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -712,7 +712,7 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
 //       CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
 //       CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
 //       CHECK: scf.if %[[A34]] {
-//       CHECK:   sparse_tensor.sort_coo  hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {nx = 2 : index, ny = 0 : index} : memref<?xindex> jointly memref<?xf32>
+//       CHECK:   sparse_tensor.sort_coo  hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
 //       CHECK: }
 //       CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
 //       CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]]  crd_mem_sz at 0 with %[[A11]]
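With sparse_tensor.sort gone, compression codegen now emits sort_coo even for
a single coordinate buffer. The CHECK lines elide the attributes; presumably
they amount to a rank-1 identity permutation with no trailing y columns, as in
this sketch (names hypothetical):

  #ID = affine_map<(d0) -> (d0)>
  sparse_tensor.sort_coo hybrid_quick_sort %count, %crd
      {perm_map = #ID, ny = 0 : index} : memref<?xindex>
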
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index b3eb50f1755dace..54cdfc690952d9a 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -178,7 +178,7 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
 //       CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor<?x?x?xf32, #{{.*}}>>
 //       CHECK-RWT: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xf32>
 //       CHECK-RWT: %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xindex>
-//       CHECK-RWT: sparse_tensor.sort_coo  hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {nx = 3 : index, ny = 0 : index}
+//       CHECK-RWT: sparse_tensor.sort_coo  hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
 //       CHECK-RWT: %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
 //       CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
 //       CHECK-RWT: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor<?x?x?xf32, #{{.*}}>>):
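The updated CHECK-RWT line expects {ny = 0 : index, perm_map = #map} even
though the op is built with perm_map listed first; MLIR prints attribute
dictionaries sorted by name, so perm_map follows ny. A sketch of the 3-D
case being checked (value names hypothetical):

  #MAP3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  sparse_tensor.sort_coo hybrid_quick_sort %nnz, %crds jointly %vals
      {perm_map = #MAP3, ny = 0 : index} : memref<?xindex> jointly memref<?xf32>
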
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 71e6eebb30261c8..c0e813dcde7c57e 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -790,60 +790,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
   return
 }
 
-// -----
-
-// TODO: a test case with empty xs doesn't work due to some parser issues.
-
-func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
-  // expected-error at +1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
-}
-
-// -----
-
-func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
-  %i20 = arith.constant 20 : index
-  // expected-error at +1 {{xs and ys need to have a dimension >= n: 10 < 20}}
-  sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
-  return
-}
 
 // -----
 
-func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
-  // expected-error at +1 {{mismatch xs element types}}
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
-  return
-}
-
-// -----
+#MAP = affine_map<(i,j) -> (i,j)>
 
 func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
   // expected-error at +1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
   return
 }
 
 // -----
 
+#MAP = affine_map<(i,j) -> (i,j)>
+
 func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
   %i20 = arith.constant 20 : index
-  // expected-error at +1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
-  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  // expected-error at +1 {{Expected dimension(xy) >= n * (rank(perm_map) + ny) got 50 < 60}}
+  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
   return
 }
 
 // -----
 
+#MAP = affine_map<(i,j) -> (i,j)>
+
 func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
   %i20 = arith.constant 20 : index
   // expected-error at +1 {{Expected dimension(y) >= n got 10 < 20}}
-  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
   return
 }
 
 // -----
 
+#NON_PERM_MAP = affine_map<(i,j) -> (i,i)>
+
+func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  // expected-error at +1 {{Expected a permutation map, got (d0, d1) -> (d0, d0)}}
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
 #CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
 
 func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
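The reworked negative tests reduce to one new rule: perm_map must be a true
permutation, and the x/y buffers must be large enough for n tuples. For
contrast (maps illustrative only):

  #PERM     = affine_map<(i, j) -> (j, i)>   // accepted: a bijection on (i, j)
  #NON_PERM = affine_map<(i, j) -> (i, i)>   // rejected: repeats i, drops j
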
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d1262cb7aea02df..d252fa559a1543f 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -612,79 +612,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
 
 // -----
 
-// CHECK-LABEL: func @sparse_sort_1d0v(
-//  CHECK-SAME: %[[A:.*]]: index,
-//  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
-//       CHECK: return %[[B]]
-func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
-  return %arg1 : memref<?xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_1d2v(
-//  CHECK-SAME: %[[A:.*]]: index,
-//  CHECK-SAME: %[[B:.*]]: memref<20xindex>,
-//  CHECK-SAME: %[[C:.*]]: memref<10xindex>,
-//  CHECK-SAME: %[[D:.*]]: memref<?xf32>)
-//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
-//       CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
-  return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_2d1v(
-//  CHECK-SAME: %[[A:.*]]: index,
-//  CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-//  CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-//  CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-//       CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_stable(
-//  CHECK-SAME: %[[A:.*]]: index,
-//  CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-//  CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-//  CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-//       CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-//       CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
-  sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
+#ID_MAP = affine_map<(i,j) -> (i,j)>
 
 // CHECK-LABEL: func @sparse_sort_coo(
 //  CHECK-SAME: %[[A:.*]]: index,
 //  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
 //       CHECK: return %[[B]]
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
+  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
   return %arg1 : memref<?xindex>
 }
 
 // -----
 
+#ID_MAP = affine_map<(i,j) -> (i,j)>
+
 // CHECK-LABEL: func @sparse_sort_coo_stable(
 //  CHECK-SAME: %[[A:.*]]: index,
 //  CHECK-SAME: %[[B:.*]]: memref<?xi64>,
 //  CHECK-SAME: %[[C:.*]]: memref<?xf32>)
-//       CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
+//       CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
 //       CHECK: return %[[B]], %[[C]]
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
+  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
   return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
 }
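
The deleted sparse_tensor.sort roundtrip cases have no one-to-one
replacement, since the op itself is gone; the simplest of them, the old 1d0v
form, would presumably be written today as (map hypothetical):

  #ID = affine_map<(d0) -> (d0)>
  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1
      {perm_map = #ID, ny = 0 : index} : memref<?xindex>
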
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index b31ac3ef3a254ad..5c308dc3c56234b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -116,7 +116,7 @@
 // CHECK:               } {"Emitted from" = "linalg.generic"}
 // CHECK:               scf.yield %[[VAL_64:.*]] : index
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             sparse_tensor.sort  hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
+// CHECK:             sparse_tensor.sort_coo  hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
 // CHECK:             %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:               %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
 // CHECK:               %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 0594b311184f4d5..394b9a8448b5438 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -28,7 +28,7 @@
 // Do the same run, but now with  VLA vectorization.
 // RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
 
-#ID_MAP = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 
 module {
   // Stores 5 values to the memref buffer.
@@ -96,11 +96,11 @@ module {
     %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
 
     // Sort "parallel arrays".
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 7, 8, 9 )
-    // CHECK: ( 7, 5, 7, 4, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -111,24 +111,25 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
-    %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v : vector<5xi32>
+    // Dump the x buffers in the key order given by perm_map so that the printed output is ordered.
     %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v : vector<5xi32>
     %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v : vector<5xi32>
+    %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v : vector<5xi32>
     %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v : vector<5xi32>
     %y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v : vector<5xi32>
     // Stable sort.
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 8, 7, 9 )
-    // CHECK: ( 7, 5, 4, 7, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    // CHECK: ( 4, 7, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -139,24 +140,24 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
-    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v2 : vector<5xi32>
     %x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v2 : vector<5xi32>
     %x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v2 : vector<5xi32>
+    %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v2 : vector<5xi32>
     %y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v2 : vector<5xi32>
     %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v2 : vector<5xi32>
     // Heap sort.
-    // CHECK: ( 1, 1, 3, 3, 10 )
-    // CHECK: ( 2, 10, 1, 1, 5 )
-    // CHECK: ( 4, 2, 9, 9, 7 )
-    // CHECK: ( 10, 6, 8, 7, 9 )
-    // CHECK: ( 7, 5, 4, 7, 9 )
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -167,14 +168,14 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
-    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
-    vector.print %x0v3 : vector<5xi32>
     %x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v3 : vector<5xi32>
     %x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x2v3 : vector<5xi32>
+    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v3 : vector<5xi32>
     %y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %y0v3 : vector<5xi32>
     %y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>


