[Mlir-commits] [mlir] a150766 - [mlir][sparse] Implement hybrid quick sort for sparse_tensor.sort.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 8 14:06:37 PST 2023


Author: bixia1
Date: 2023-02-08T14:06:31-08:00
New Revision: a1507668807e6108c12ffecf3740cb339b15018d

URL: https://github.com/llvm/llvm-project/commit/a1507668807e6108c12ffecf3740cb339b15018d
DIFF: https://github.com/llvm/llvm-project/commit/a1507668807e6108c12ffecf3740cb339b15018d.diff

LOG: [mlir][sparse] Implement hybrid quick sort for sparse_tensor.sort.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D143227
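
The generated IR follows an introsort-style strategy: insertion sort for short ranges, heap sort once a recursion-depth budget is exhausted, and quicksort partitioning otherwise. Below is a self-contained C++ sketch of that control flow; the helper names, the Lomuto-style partition, and passing depthLimit by value are simplifications for illustration, while the small-range threshold of 30 and the 2 * log2(n) initial depth budget mirror the constants in the patch (the actual code calls the dialect's own partition, insertion-sort, and heap-sort helpers and keeps the depth limit in a scratch memref).

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Insertion sort over the half-open range [lo, hi).
static void insertionSort(std::vector<int> &v, size_t lo, size_t hi) {
  for (size_t i = lo + 1; i < hi; ++i)
    for (size_t j = i; j > lo && v[j] < v[j - 1]; --j)
      std::swap(v[j], v[j - 1]);
}

// Lomuto-style partition over [lo, hi); returns the pivot's final index.
static size_t partition(std::vector<int> &v, size_t lo, size_t hi) {
  int pivot = v[hi - 1];
  size_t p = lo;
  for (size_t i = lo; i + 1 < hi; ++i)
    if (v[i] < pivot)
      std::swap(v[i], v[p++]);
  std::swap(v[p], v[hi - 1]);
  return p;
}

static void hybridQuickSort(std::vector<int> &v, size_t lo, size_t hi,
                            int64_t depthLimit) {
  if (lo + 1 >= hi)
    return;                                // fewer than two elements
  if (hi - lo <= 30) {                     // small range: insertion sort
    insertionSort(v, lo, hi);
    return;
  }
  if (--depthLimit <= 0) {                 // recursion too deep: heap sort
    std::make_heap(v.begin() + lo, v.begin() + hi);
    std::sort_heap(v.begin() + lo, v.begin() + hi);
    return;
  }
  size_t p = partition(v, lo, hi);
  hybridQuickSort(v, lo, p, depthLimit);   // pivot at p is already placed
  hybridQuickSort(v, p + 1, hi, depthLimit);
}

void hybridSort(std::vector<int> &v) {
  int64_t bits = 0;                        // depth budget = 2 * log2(n)
  for (size_t n = v.size(); n > 0; n >>= 1)
    ++bits;
  hybridQuickSort(v, 0, v.size(), 2 * bits);
}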

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index b07991ef5f64e..12cfd3bdcca0b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -275,6 +275,11 @@ inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
   return builder.create<arith::ConstantIndexOp>(loc, i);
 }
 
+/// Generates a constant of `i64` type.
+inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) {
+  return builder.create<arith::ConstantIntOp>(loc, i, 64);
+}
+
 /// Generates a constant of `i32` type.
 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
   return builder.create<arith::ConstantIntOp>(loc, i, 32);

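As a usage illustration only, not part of this patch, the new constantI64 helper is called the same way as the existing constantI32 and constantIndex helpers. In the snippet below, builder and loc are assumed to be an OpBuilder and Location already in scope, and counter an existing i64 Value:

// Hypothetical call site: decrement an i64 counter by a constant 1.
Value one = constantI64(builder, loc, 1);
Value next = builder.create<arith::SubIOp>(loc, counter, one);
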
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 90ca39fe650d5..3e6157001266f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -38,12 +39,13 @@ static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
 static constexpr const char kBinarySearchFuncNamePrefix[] =
     "_sparse_binary_search_";
-static constexpr const char kSortNonstableFuncNamePrefix[] =
-    "_sparse_sort_nonstable_";
+static constexpr const char kHybridQuickSortFuncNamePrefix[] =
+    "_sparse_hybrid_qsort_";
 static constexpr const char kSortStableFuncNamePrefix[] =
     "_sparse_sort_stable_";
 static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
 static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
+static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
 
 using FuncGeneratorType = function_ref<void(
     OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
@@ -916,41 +918,19 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc);
 }
 
-/// Creates a function to perform quick sort on the value in the range of
-/// index [lo, hi).
-//
-// The generate IR corresponds to this C like algorithm:
-// void quickSort(lo, hi, data) {
-//   if (lo < hi) {
-//        p = partition(low, high, data);
-//        quickSort(lo, p, data);
-//        quickSort(p + 1, hi, data);
-//   }
-// }
-static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
-                                    func::FuncOp func, uint64_t nx, uint64_t ny,
-                                    bool isCoo, uint32_t nTrailingP) {
-  (void)nTrailingP;
-  OpBuilder::InsertionGuard insertionGuard(builder);
-  Block *entryBlock = func.addEntryBlock();
-  builder.setInsertionPointToStart(entryBlock);
-
+static void createQuickSort(OpBuilder &builder, ModuleOp module,
+                            func::FuncOp func, ValueRange args, uint64_t nx,
+                            uint64_t ny, bool isCoo, uint32_t nTrailingP) {
   MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
-  ValueRange args = entryBlock->getArguments();
   Value lo = args[loIdx];
   Value hi = args[hiIdx];
-  Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
-  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
-
-  // The if-stmt true branch.
-  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
       builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
-      ny, isCoo, args, createPartitionFunc);
-  auto p = builder.create<func::CallOp>(
-      loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
+      ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
+  auto p = builder.create<func::CallOp>(loc, partitionFunc,
+                                        TypeRange{IndexType::get(context)},
+                                        args.drop_back(nTrailingP));
 
   SmallVector<Value> lowOperands{lo, p.getResult(0)};
   lowOperands.append(args.begin() + xStartIdx, args.end());
@@ -962,10 +942,6 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
       hi};
   highOperands.append(args.begin() + xStartIdx, args.end());
   builder.create<func::CallOp>(loc, func, highOperands);
-
-  // After the if-stmt.
-  builder.setInsertionPointAfter(ifOp);
-  builder.create<func::ReturnOp>(loc);
 }
 
 /// Creates a function to perform insertion sort on the values in the range of
@@ -1054,6 +1030,116 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc);
 }
 
+/// Creates a function to perform quick sort or a hybrid quick sort on the
+/// values in the range of index [lo, hi).
+//
+//
+// When nTrailingP == 0, the generated IR corresponds to this C-like algorithm:
+// void quickSort(lo, hi, data) {
+//   if (lo + 1 < hi) {
+//     p = partition(lo, hi, data);
+//     quickSort(lo, p, data);
+//     quickSort(p + 1, hi, data);
+//   }
+// }
+//
+// When nTrailingP == 1, the generated IR corresponds to this C-like algorithm:
+// void hybridQuickSort(lo, hi, data, depthLimit) {
+//   if (lo + 1 < hi) {
+//     len = hi - lo;
+//     if (len <= limit) {
+//       insertionSort(lo, hi, data);
+//     } else {
+//       depthLimit--;
+//       if (depthLimit <= 0) {
+//         heapSort(lo, hi, data);
+//       } else {
+//         p = partition(lo, hi, data);
+//         hybridQuickSort(lo, p, data, depthLimit);
+//         hybridQuickSort(p + 1, hi, data, depthLimit);
+//       }
+//       depthLimit++;
+//     }
+//   }
+// }
+//
+static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
+                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                bool isCoo, uint32_t nTrailingP) {
+  assert(nTrailingP == 1 || nTrailingP == 0);
+  bool isHybrid = (nTrailingP == 1);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value loCmp =
+      builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+  // The if-stmt true branch.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value pDepthLimit;
+  Value savedDepthLimit;
+  scf::IfOp depthIf;
+
+  if (isHybrid) {
+    Value len = builder.create<arith::SubIOp>(loc, hi, lo);
+    Value lenLimit = constantIndex(builder, loc, 30);
+    Value lenCond = builder.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ule, len, lenLimit);
+    scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
+
+    // When len <= limit.
+    builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
+    FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
+        builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
+        args.drop_back(nTrailingP), createSortStableFunc);
+    builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
+                                 ValueRange(args.drop_back(nTrailingP)));
+
+    // When len > limit.
+    builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
+    pDepthLimit = args.back();
+    savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
+    Value depthLimit = builder.create<arith::SubIOp>(
+        loc, savedDepthLimit, constantI64(builder, loc, 1));
+    builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+    Value depthCond =
+        builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
+                                      depthLimit, constantI64(builder, loc, 0));
+    depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
+
+    // When depth exceeds limit.
+    builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
+    FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
+        builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
+        args.drop_back(nTrailingP), createHeapSortFunc);
+    builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
+                                 ValueRange(args.drop_back(nTrailingP)));
+
+    // When depth doesn't exceed limit.
+    builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
+  }
+
+  createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+
+  if (isHybrid) {
+    // Restore depthLimit.
+    builder.setInsertionPointAfter(depthIf);
+    builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
+  }
+
+  // After the if-stmt.
+  builder.setInsertionPointAfter(ifOp);
+  builder.create<func::ReturnOp>(loc);
+}
+
 /// Implements the rewriting for operator sort and sort_coo.
 template <typename OpTy>
 LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
@@ -1078,10 +1164,30 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
   FuncGeneratorType funcGenerator;
   uint32_t nTrailingP = 0;
   switch (op.getAlgorithm()) {
-  case SparseTensorSortKind::HybridQuickSort:
+  case SparseTensorSortKind::HybridQuickSort: {
+    funcName = kHybridQuickSortFuncNamePrefix;
+    funcGenerator = createQuickSortFunc;
+    nTrailingP = 1;
+    Value pDepthLimit = rewriter.create<memref::AllocaOp>(
+        loc, MemRefType::get({}, rewriter.getI64Type()));
+    operands.push_back(pDepthLimit);
+    // As a heuristic, set depthLimit = 2 * log2(n).
+    Value lo = operands[loIdx];
+    Value hi = operands[hiIdx];
+    Value len = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getI64Type(),
+        rewriter.create<arith::SubIOp>(loc, hi, lo));
+    Value depthLimit = rewriter.create<arith::SubIOp>(
+        loc, constantI64(rewriter, loc, 64),
+        rewriter.create<math::CountLeadingZerosOp>(loc, len));
+    depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
+                                                constantI64(rewriter, loc, 1));
+    rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+    break;
+  }
   case SparseTensorSortKind::QuickSort:
-    funcName = kSortNonstableFuncNamePrefix;
-    funcGenerator = createSortNonstableFunc;
+    funcName = kQuickSortFuncNamePrefix;
+    funcGenerator = createQuickSortFunc;
     break;
   case SparseTensorSortKind::InsertionSortStable:
     funcName = kSortStableFuncNamePrefix;

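The hybrid_quick_sort case above seeds the depth limit with roughly 2 * log2(n), which the generated IR computes with math.ctlz and a left shift. A minimal, portable C++ restatement of that heuristic follows; the function name is hypothetical and n is assumed to be nonzero:

#include <cstdint>

// Roughly 2 * log2(n): twice the number of significant bits in n,
// i.e. 2 * (64 - countLeadingZeros(n)) for a nonzero 64-bit n.
int64_t initialDepthLimit(uint64_t n) {
  int64_t bits = 0;
  while (n != 0) {   // portable stand-in for 64 - ctlz(n)
    ++bits;
    n >>= 1;
  }
  return bits << 1;  // shift left by 1 == multiply by 2
}

For example, a 1000-element range yields a depth limit of 20, after which the generated code falls back to heap sort.
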
diff  --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 68c8366eb822e..dbe0c972e6614 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -125,24 +125,25 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // CHECK:           return %[[W:.*]]#2
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_sparse_sort_nonstable_1_i8_f32_index(
+// CHECK-LABEL:   func.func private @_sparse_qsort_1_i8_f32_index(
 // CHECK-SAME:                                                   %[[L:arg0]]: index,
 // CHECK-SAME:                                                   %[[H:.*]]: index,
 // CHECK-SAME:                                                   %[[X0:.*]]: memref<?xi8>,
 // CHECK-SAME:                                                   %[[Y0:.*]]: memref<?xf32>,
 // CHECK-SAME:                                                   %[[Y1:.*]]: memref<?xindex>) {
 // CHECK:           %[[C1:.*]] = arith.constant 1
-// CHECK:           %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
+// CHECK:           %[[Lb:.*]] = arith.addi %[[L]], %[[C1]]
+// CHECK:           %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H]]
 // CHECK:           scf.if %[[COND]] {
 // CHECK:             %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:             func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             func.call @_sparse_qsort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK:             %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK:             func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             func.call @_sparse_qsort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
 
-// CHECK-LABEL:   func.func @sparse_sort_1d2v(
+// CHECK-LABEL:   func.func @sparse_sort_1d2v_quick(
 // CHECK-SAME:                                %[[N:.*]]: index,
 // CHECK-SAME:                                %[[X0:.*]]: memref<10xi8>,
 // CHECK-SAME:                                %[[Y0:.*]]: memref<?xf32>,
@@ -150,12 +151,12 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
 // CHECK:           %[[C0:.*]] = arith.constant 0
 // CHECK:           %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
 // CHECK:           %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
-// CHECK:           call @_sparse_sort_nonstable_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK:           call @_sparse_qsort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
 // CHECK:           return %[[X0]], %[[Y0]], %[[Y1]]
 // CHECK:         }
-func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
+func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
    -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
-  sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+  sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
 }
 
@@ -167,9 +168,28 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
 // CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG:     func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL:   func.func @sparse_sort_3d
-func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+// CHECK-DAG:     func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL:   func.func @sparse_sort_3d_quick
+func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
+
+// -----
+
+// Only check the generated supporting functions here. We have integration tests
+// to verify the correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG:     func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
+// CHECK-DAG:     func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG:     func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: memref<i64>) {
+// CHECK-LABEL:   func.func @sparse_sort_3d_hybrid
+func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
   sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
 }
@@ -210,9 +230,28 @@ func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: mem
 // CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG:     func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-LABEL:   func.func @sparse_sort_coo
-func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+// CHECK-DAG:     func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL:   func.func @sparse_sort_coo_quick
+func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+  sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
+
+// -----
+
+// Only check the generated supporting functions here. We have integration tests
+// to verify the correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG:     func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG:     func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG:     func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG:     func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: memref<i64>) {
+// CHECK-LABEL:   func.func @sparse_sort_coo_hybrid
+func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
   sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 3c2d9cf62e5c8..d3ef2fa4ac325 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -49,8 +49,9 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
 
     // Sort 0 elements.
+    // Quick sort.
     // CHECK: [10,  2,  0,  5,  1]
-    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
+    sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     // Stable sort.
     // CHECK: [10,  2,  0,  5,  1]
@@ -60,10 +61,15 @@ module {
     // CHECK: [10,  2,  0,  5,  1]
     sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Hybrid sort.
+    // CHECK: [10,  2,  0,  5,  1]
+    sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Sort the first 4 elements, with the last valid value untouched.
+    // Quick sort.
     // CHECK: [0,  2,  5, 10,  1]
-    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
+    sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     // Stable sort.
     // CHECK: [0,  2,  5,  10,  1]
@@ -77,6 +83,10 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Hybrid sort.
+    // CHECK: [0,  2,  5, 10,  1]
+    sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Prepare more buffers of different dimensions.
     %x1s = memref.alloc() : memref<10xi32>
@@ -99,7 +109,7 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort hybrid_quick_sort %i5, %x0, %x1, %x2 jointly %y0
+    sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
       : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 46e1020f8d88e..70119f8cead15 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -92,7 +92,7 @@ module {
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    sparse_tensor.sort_coo hybrid_quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+    sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
       : memref<?xi32> jointly memref<?xi32>
     %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
     vector.print %x0v : vector<5xi32>

More information about the Mlir-commits mailing list