[Mlir-commits] [mlir] [mlir][sparse] renaming sparse_tensor.sort_coo to sparse_tensor.sort (PR #68161)

Peiming Liu llvmlistbot at llvm.org
Tue Oct 3 14:40:20 PDT 2023


https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/68161

Rationale: the operation does not always sort COO tensors (for example, it is also used when lowering sparse_tensor.compress).

From 65a55125c41b41b4c3f8d445b5fe1e5551f587cb Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 3 Oct 2023 21:37:23 +0000
Subject: [PATCH] [mlir][sparse] renaming sparse_tensor.sort_coo to
 sparse_tensor.sort

---
 .../mlir/Dialect/SparseTensor/IR/SparseTensorOps.td  | 12 +++---------
 .../Dialect/SparseTensor/IR/SparseTensorDialect.cpp  |  2 +-
 .../Transforms/SparseBufferRewriting.cpp             |  8 ++++----
 .../SparseTensor/Transforms/SparseTensorCodegen.cpp  |  4 ++--
 .../SparseTensor/Transforms/SparseTensorPasses.cpp   |  2 +-
 .../Transforms/SparseTensorRewriting.cpp             |  2 +-
 mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir |  8 ++++----
 mlir/test/Dialect/SparseTensor/codegen.mlir          |  6 +++---
 .../Dialect/SparseTensor/convert_sparse2sparse.mlir  |  2 +-
 mlir/test/Dialect/SparseTensor/invalid.mlir          |  8 ++++----
 mlir/test/Dialect/SparseTensor/roundtrip.mlir        |  8 ++++----
 .../Dialect/SparseTensor/sparse_matmul_codegen.mlir  |  2 +-
 .../SparseTensor/CPU/sparse_rewrite_sort_coo.mlir    |  6 +++---
 13 files changed, 32 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index c566f674f7d4011..7ea5ca23f122a8a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,7 +762,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
 // Sparse Tensor Sorting Operations.
 //===----------------------------------------------------------------------===//
 
-def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
+def SparseTensor_SortOp : SparseTensor_Op<"sort">,
     Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
                Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
                AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
@@ -770,7 +770,7 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
   let summary = "Sorts the arrays in xs and ys lexicographically on the "
                 "integral values found in the xs list";
   let description = [{
-    Sparse_tensor.sort_coo sort the `xs` values along with some `ys` values
+    Sparse_tensor.sort sort the `xs` values along with some `ys` values
     that are put in a single linear buffer `xy`.
     The affine map attribute `perm_map` specifies the permutation to be applied on
     the `xs` before comparison, the rank of the permutation map
@@ -787,15 +787,9 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
     Example:
 
     ```mlir
-    sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
+    sparse_tensor.sort insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
       : memref<?xindex>
     ```
-
-    ```mlir
-    sparse_tensor.sort hybrid_quick_sort %n, %xy jointly %y1
-      { nx = 2 : index, ny = 2 : index}
-      : memref<?xi64> jointly memref<?xf32>
-    ```
   }];
 
   let assemblyFormat = "$algorithm $n"
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3897e1b9ea3597c..af262cef6d1d5e3 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1363,7 +1363,7 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
-LogicalResult SortCooOp::verify() {
+LogicalResult SortOp::verify() {
   AffineMap xPerm = getPermMap();
   uint64_t nx = xPerm.getNumDims();
   if (nx < 1)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 3181395a474cfec..ced7983af324c1f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -1398,11 +1398,11 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
 };
 
 /// Sparse rewriting rule for the sort_coo operator.
-struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
+struct SortRewriter : public OpRewritePattern<SortOp> {
 public:
-  using OpRewritePattern<SortCooOp>::OpRewritePattern;
+  using OpRewritePattern<SortOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SortCooOp op,
+  LogicalResult matchAndRewrite(SortOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<Value> xys;
     xys.push_back(op.getXy());
@@ -1427,5 +1427,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
                                          bool enableBufferInitialization) {
   patterns.add<PushBackRewriter>(patterns.getContext(),
                                  enableBufferInitialization);
-  patterns.add<SortCooRewriter>(patterns.getContext());
+  patterns.add<SortRewriter>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index f02276fba0d526b..d3f316348bd39c1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -931,7 +931,7 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     // If the innermost level is ordered, we need to sort the coordinates
     // in the "added" array prior to applying the compression.
     if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
-      rewriter.create<SortCooOp>(
+      rewriter.create<SortOp>(
           loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
           rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
     // While performing the insertions, we also need to reset the elements
@@ -1531,7 +1531,7 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
           rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
       auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
-      rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
+      rewriter.create<SortOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                                  rewriter.getIndexAttr(0),
                                  SparseTensorSortKind::HybridQuickSort);
       rewriter.setInsertionPointAfter(ifOp);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index 7d2f0c7f139cda5..f50d3d4606554a1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -207,7 +207,7 @@ struct SparseTensorCodegenPass
     ConversionTarget target(*ctx);
     // Most ops in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
-    target.addLegalOp<SortCooOp>();
+    target.addLegalOp<SortOp>();
     target.addLegalOp<PushBackOp>();
     // Storage specifier outlives sparse tensor pipeline.
     target.addLegalOp<GetStorageSpecifierOp>();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 3b5b224434bead1..c5939920ccbc45f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1220,7 +1220,7 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
       assert(xPerm.isPermutation()); // must be a permutation.
 
       Value xs = genToCoordinatesBuffer(rewriter, loc, src);
-      rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y}, xPerm,
+      rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y}, xPerm,
                                  rewriter.getIndexAttr(0),
                                  SparseTensorSortKind::HybridQuickSort);
     }
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index c96a55aa1e8b2f3..cafb431b7530638 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -84,7 +84,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // CHECK-DAG:     func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_quick
 func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
@@ -103,7 +103,7 @@ func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2:
 // CHECK-DAG:     func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_hybrid
 func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
@@ -118,7 +118,7 @@ func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2:
 // CHECK-DAG:     func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_stable
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
@@ -133,6 +133,6 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
 // CHECK-DAG:     func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
 // CHECK-LABEL:   func.func @sparse_sort_coo_heap
 func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
-  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  sparse_tensor.sort heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 69a9c274a861ce1..9af5a46deadc9d2 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -423,7 +423,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
 //   CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
 //   CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
 //   CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
+//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
 //       CHECK:   %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
 //       CHECK:   %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -471,7 +471,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
 //       CHECK:     %[[A11:.*]] = arith.constant 0.000000e+00 : f64
 //       CHECK:     %[[A12:.*]] = arith.constant 1 : index
 //       CHECK:     %[[A13:.*]] = arith.constant 0 : index
-//       CHECK:     sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
+//       CHECK:     sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]]
 //       CHECK:     %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
 //       CHECK:       %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
 //       CHECK:       %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -699,7 +699,7 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
 //       CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
 //       CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
 //       CHECK: scf.if %[[A34]] {
-//       CHECK:   sparse_tensor.sort_coo  hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
+//       CHECK:   sparse_tensor.sort  hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
 //       CHECK: }
 //       CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
 //       CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]]  crd_mem_sz at 0 with %[[A11]]
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index 2a2619daf493654..c373fd23bbef492 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -177,7 +177,7 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
 //       CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor<?x?x?xf32, #{{.*}}>>
 //       CHECK-RWT: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xf32>
 //       CHECK-RWT: %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xindex>
-//       CHECK-RWT: sparse_tensor.sort_coo  hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
+//       CHECK-RWT: sparse_tensor.sort  hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
 //       CHECK-RWT: %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
 //       CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
 //       CHECK-RWT: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor<?x?x?xf32, #{{.*}}>>):
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 3c33ba42f4d388f..2df4237efa0bbe5 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -794,7 +794,7 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
 
 func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
   // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
+  sparse_tensor.sort insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
   return
 }
 
@@ -805,7 +805,7 @@ func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
 func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(xy) >= n * (rank(perm_map) + ny) got 50 < 60}}
-  sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
+  sparse_tensor.sort hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
   return
 }
 
@@ -816,7 +816,7 @@ func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
 func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
   %i20 = arith.constant 20 : index
   // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
-  sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  sparse_tensor.sort insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
   return
 }
 
@@ -826,7 +826,7 @@ func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10
 
 func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
   // expected-error@+1 {{Expected a permutation map, got (d0, d1) -> (d0, d0)}}
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
   return %arg1 : memref<?xindex>
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 317a0735fb8bf84..82267be34b93846 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -612,10 +612,10 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
 // CHECK-LABEL: func @sparse_sort_coo(
 //  CHECK-SAME: %[[A:.*]]: index,
 //  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-//       CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
+//       CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
 //       CHECK: return %[[B]]
 func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
-  sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
+  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
   return %arg1 : memref<?xindex>
 }
 
@@ -627,9 +627,9 @@ func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xin
 //  CHECK-SAME: %[[A:.*]]: index,
 //  CHECK-SAME: %[[B:.*]]: memref<?xi64>,
 //  CHECK-SAME: %[[C:.*]]: memref<?xf32>)
-//       CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
+//       CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
 //       CHECK: return %[[B]], %[[C]]
 func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
-  sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
+  sparse_tensor.sort insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
   return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 4fbd36906eacf0b..12d443ca679207e 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -116,7 +116,7 @@
 // CHECK:               } {"Emitted from" = "linalg.generic"}
 // CHECK:               scf.yield %[[VAL_64:.*]] : index
 // CHECK:             } {"Emitted from" = "linalg.generic"}
-// CHECK:             sparse_tensor.sort_coo  hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
+// CHECK:             sparse_tensor.sort  hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
 // CHECK:             %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
 // CHECK:               %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
 // CHECK:               %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 394b9a8448b5438..8df0f3208251988 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -111,7 +111,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     // Dumps memory in the same order as the perm_map such that the output is ordered.
     %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
@@ -140,7 +140,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v2 : vector<5xi32>
@@ -168,7 +168,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
+    sparse_tensor.sort heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x1v3 : vector<5xi32>



More information about the Mlir-commits mailing list