[Mlir-commits] [mlir] a150766 - [mlir][sparse] Implement hybrid quick sort for sparse_tensor.sort.
llvmlistbot at llvm.org
Wed Feb 8 14:06:37 PST 2023
Author: bixia1
Date: 2023-02-08T14:06:31-08:00
New Revision: a1507668807e6108c12ffecf3740cb339b15018d
URL: https://github.com/llvm/llvm-project/commit/a1507668807e6108c12ffecf3740cb339b15018d
DIFF: https://github.com/llvm/llvm-project/commit/a1507668807e6108c12ffecf3740cb339b15018d.diff
LOG: [mlir][sparse] Implement hybrid quick sort for sparse_tensor.sort.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D143227
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
Removed:
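Background for readers of the patch below: the generated MLIR follows an introsort-style strategy, quick sort by default, insertion sort for short ranges, and heap sort once a recursion depth limit is exhausted. The following is a minimal standalone C++ sketch of that control flow, written only as orientation; it mirrors the hybridQuickSort pseudocode comment added in createQuickSortFunc, but the Lomuto partition and the std::stable_sort stand-in for the insertion sort are illustrative assumptions, not the code this commit emits.

// Illustrative sketch only -- not part of this commit. partitionRange() is a
// plain Lomuto partition standing in for the generated _sparse_partition_*
// helper; std::stable_sort stands in for the generated stable insertion sort.
#include <algorithm>
#include <bit>
#include <cstdint>
#include <cstdio>
#include <vector>

static int64_t partitionRange(std::vector<int> &data, int64_t lo, int64_t hi) {
  // Lomuto partition over [lo, hi): returns the final pivot position p, so
  // the caller recurses on [lo, p) and [p + 1, hi).
  int pivot = data[hi - 1];
  int64_t i = lo;
  for (int64_t j = lo; j < hi - 1; ++j)
    if (data[j] < pivot)
      std::swap(data[i++], data[j]);
  std::swap(data[i], data[hi - 1]);
  return i;
}

static void hybridQuickSort(std::vector<int> &data, int64_t lo, int64_t hi,
                            int64_t &depthLimit) {
  if (lo + 1 >= hi)
    return;
  // Short ranges: use the stable insertion sort. The threshold 30 is the same
  // constant the patch uses.
  if (hi - lo <= 30) {
    std::stable_sort(data.begin() + lo, data.begin() + hi);
    return;
  }
  // Recursion too deep: fall back to heap sort to keep an O(n log n) worst case.
  if (--depthLimit <= 0) {
    std::make_heap(data.begin() + lo, data.begin() + hi);
    std::sort_heap(data.begin() + lo, data.begin() + hi);
  } else {
    int64_t p = partitionRange(data, lo, hi);
    hybridQuickSort(data, lo, p, depthLimit);
    hybridQuickSort(data, p + 1, hi, depthLimit);
  }
  ++depthLimit; // restore the limit, as the generated IR does after the branch
}

int main() {
  std::vector<int> data = {10, 2, 0, 5, 1, 9, 4, 7, 3, 8, 6};
  int64_t n = static_cast<int64_t>(data.size());
  // Heuristic from the patch: depthLimit = 2 * log2(n).
  int64_t depthLimit = 2 * std::bit_width(static_cast<uint64_t>(n));
  hybridQuickSort(data, 0, n, depthLimit);
  for (int v : data)
    std::printf("%d ", v);
  std::printf("\n");
}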
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
index b07991ef5f64e..12cfd3bdcca0b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -275,6 +275,11 @@ inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
return builder.create<arith::ConstantIndexOp>(loc, i);
}
+/// Generates a constant of `i64` type.
+inline Value constantI64(OpBuilder &builder, Location loc, int64_t i) {
+ return builder.create<arith::ConstantIntOp>(loc, i, 64);
+}
+
/// Generates a constant of `i32` type.
inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
return builder.create<arith::ConstantIntOp>(loc, i, 32);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 90ca39fe650d5..3e6157001266f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
@@ -38,12 +39,13 @@ static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
"_sparse_binary_search_";
-static constexpr const char kSortNonstableFuncNamePrefix[] =
- "_sparse_sort_nonstable_";
+static constexpr const char kHybridQuickSortFuncNamePrefix[] =
+ "_sparse_hybrid_qsort_";
static constexpr const char kSortStableFuncNamePrefix[] =
"_sparse_sort_stable_";
static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
+static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
using FuncGeneratorType = function_ref<void(
OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
@@ -916,41 +918,19 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
builder.create<func::ReturnOp>(loc);
}
-/// Creates a function to perform quick sort on the value in the range of
-/// index [lo, hi).
-//
-// The generate IR corresponds to this C like algorithm:
-// void quickSort(lo, hi, data) {
-// if (lo < hi) {
-// p = partition(low, high, data);
-// quickSort(lo, p, data);
-// quickSort(p + 1, hi, data);
-// }
-// }
-static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP) {
- (void)nTrailingP;
- OpBuilder::InsertionGuard insertionGuard(builder);
- Block *entryBlock = func.addEntryBlock();
- builder.setInsertionPointToStart(entryBlock);
-
+static void createQuickSort(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, ValueRange args, uint64_t nx,
+ uint64_t ny, bool isCoo, uint32_t nTrailingP) {
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
- ValueRange args = entryBlock->getArguments();
Value lo = args[loIdx];
Value hi = args[hiIdx];
- Value cond =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lo, hi);
- scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
-
- // The if-stmt true branch.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
- ny, isCoo, args, createPartitionFunc);
- auto p = builder.create<func::CallOp>(
- loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
+ ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
+ auto p = builder.create<func::CallOp>(loc, partitionFunc,
+ TypeRange{IndexType::get(context)},
+ args.drop_back(nTrailingP));
SmallVector<Value> lowOperands{lo, p.getResult(0)};
lowOperands.append(args.begin() + xStartIdx, args.end());
@@ -962,10 +942,6 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
hi};
highOperands.append(args.begin() + xStartIdx, args.end());
builder.create<func::CallOp>(loc, func, highOperands);
-
- // After the if-stmt.
- builder.setInsertionPointAfter(ifOp);
- builder.create<func::ReturnOp>(loc);
}
/// Creates a function to perform insertion sort on the values in the range of
@@ -1054,6 +1030,116 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
builder.create<func::ReturnOp>(loc);
}
+/// Creates a function to perform quick sort or a hybrid quick sort on the
+/// values in the range of index [lo, hi).
+//
+//
+// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
+// void quickSort(lo, hi, data) {
+// if (lo + 1 < hi) {
+// p = partition(low, high, data);
+// quickSort(lo, p, data);
+// quickSort(p + 1, hi, data);
+// }
+// }
+//
+// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
+// void hybridQuickSort(lo, hi, data, depthLimit) {
+// if (lo + 1 < hi) {
+// len = hi - lo;
+// if (len <= limit) {
+// insertionSort(lo, hi, data);
+// } else {
+// depthLimit --;
+// if (depthLimit <= 0) {
+// heapSort(lo, hi, data);
+// } else {
+// p = partition(low, high, data);
+// quickSort(lo, p, data);
+// quickSort(p + 1, hi, data);
+// }
+// depthLimit ++;
+// }
+// }
+// }
+//
+static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo, uint32_t nTrailingP) {
+ assert(nTrailingP == 1 || nTrailingP == 0);
+ bool isHybrid = (nTrailingP == 1);
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+
+ Location loc = func.getLoc();
+ ValueRange args = entryBlock->getArguments();
+ Value lo = args[loIdx];
+ Value hi = args[hiIdx];
+ Value loCmp =
+ builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
+ Value cond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+
+ // The if-stmt true branch.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value pDepthLimit;
+ Value savedDepthLimit;
+ scf::IfOp depthIf;
+
+ if (isHybrid) {
+ Value len = builder.create<arith::SubIOp>(loc, hi, lo);
+ Value lenLimit = constantIndex(builder, loc, 30);
+ Value lenCond = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, len, lenLimit);
+ scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
+
+ // When len <= limit.
+ builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
+ FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
+ builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
+ args.drop_back(nTrailingP), createSortStableFunc);
+ builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
+ ValueRange(args.drop_back(nTrailingP)));
+
+ // When len > limit.
+ builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
+ pDepthLimit = args.back();
+ savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
+ Value depthLimit = builder.create<arith::SubIOp>(
+ loc, savedDepthLimit, constantI64(builder, loc, 1));
+ builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+ Value depthCond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
+ depthLimit, constantI64(builder, loc, 0));
+ depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
+
+ // When depth exceeds limit.
+ builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
+ FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
+ builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
+ args.drop_back(nTrailingP), createHeapSortFunc);
+ builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
+ ValueRange(args.drop_back(nTrailingP)));
+
+ // When depth doesn't exceed limit.
+ builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
+ }
+
+ createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+
+ if (isHybrid) {
+ // Restore depthLimit.
+ builder.setInsertionPointAfter(depthIf);
+ builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
+ }
+
+ // After the if-stmt.
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<func::ReturnOp>(loc);
+}
+
/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
@@ -1078,10 +1164,30 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
FuncGeneratorType funcGenerator;
uint32_t nTrailingP = 0;
switch (op.getAlgorithm()) {
- case SparseTensorSortKind::HybridQuickSort:
+ case SparseTensorSortKind::HybridQuickSort: {
+ funcName = kHybridQuickSortFuncNamePrefix;
+ funcGenerator = createQuickSortFunc;
+ nTrailingP = 1;
+ Value pDepthLimit = rewriter.create<memref::AllocaOp>(
+ loc, MemRefType::get({}, rewriter.getI64Type()));
+ operands.push_back(pDepthLimit);
+ // As a heuristic, set depthLimit = 2 * log2(n).
+ Value lo = operands[loIdx];
+ Value hi = operands[hiIdx];
+ Value len = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI64Type(),
+ rewriter.create<arith::SubIOp>(loc, hi, lo));
+ Value depthLimit = rewriter.create<arith::SubIOp>(
+ loc, constantI64(rewriter, loc, 64),
+ rewriter.create<math::CountLeadingZerosOp>(loc, len));
+ depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
+ constantI64(rewriter, loc, 1));
+ rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+ break;
+ }
case SparseTensorSortKind::QuickSort:
- funcName = kSortNonstableFuncNamePrefix;
- funcGenerator = createSortNonstableFunc;
+ funcName = kQuickSortFuncNamePrefix;
+ funcGenerator = createQuickSortFunc;
break;
case SparseTensorSortKind::InsertionSortStable:
funcName = kSortStableFuncNamePrefix;
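The depth limit in the hunk above is materialized in the IR as (64 - ctlz(len)) << 1, i.e. 2 * (floor(log2(len)) + 1), which approximates the 2 * log2(n) heuristic. A small standalone check of that arithmetic, illustrative only and using the C++20 <bit> facilities rather than anything from this patch:

// Standalone sanity check of the depth-limit heuristic; not part of the patch.
#include <bit>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  for (uint64_t len : {31ull, 1000ull, 1ull << 20}) {
    // Mirrors the generated IR: depthLimit = (64 - ctlz(len)) << 1.
    uint64_t depthLimit = uint64_t(64 - std::countl_zero(len)) << 1;
    std::printf("len=%llu depthLimit=%llu 2*log2(len)=%.2f\n",
                (unsigned long long)len, (unsigned long long)depthLimit,
                2.0 * std::log2((double)len));
  }
  // e.g. len=1000 gives depthLimit=20, close to 2*log2(1000) ~ 19.93.
  return 0;
}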
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 68c8366eb822e..dbe0c972e6614 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -125,24 +125,25 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK: return %[[W:.*]]#2
// CHECK: }
-// CHECK-LABEL: func.func private @_sparse_sort_nonstable_1_i8_f32_index(
+// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index(
// CHECK-SAME: %[[L:arg0]]: index,
// CHECK-SAME: %[[H:.*]]: index,
// CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
+// CHECK: %[[Lb:.*]] = arith.addi %[[L]], %[[C1]]
+// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H]]
// CHECK: scf.if %[[COND]] {
// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK: func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
// CHECK: }
// CHECK: return
// CHECK: }
-// CHECK-LABEL: func.func @sparse_sort_1d2v(
+// CHECK-LABEL: func.func @sparse_sort_1d2v_quick(
// CHECK-SAME: %[[N:.*]]: index,
// CHECK-SAME: %[[X0:.*]]: memref<10xi8>,
// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
@@ -150,12 +151,12 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK: %[[C0:.*]] = arith.constant 0
// CHECK: %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
// CHECK: %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
-// CHECK: call @_sparse_sort_nonstable_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK: call @_sparse_qsort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
// CHECK: return %[[X0]], %[[Y0]], %[[Y1]]
// CHECK: }
-func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
+func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
-> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
- sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+ sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
}
@@ -167,9 +168,28 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d
-func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+// CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_3d_quick
+func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+ sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+ return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
+
+// -----
+
+// Only check the generated supporting function now. We have integration test
+// to verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: memref<i64>) {
+// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
+func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
@@ -210,9 +230,28 @@ func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: mem
// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-LABEL: func.func @sparse_sort_coo
-func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+// CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL: func.func @sparse_sort_coo_quick
+func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+ sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
+
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: memref<i64>) {
+// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
+func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 3c2d9cf62e5c8..d3ef2fa4ac325 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -49,8 +49,9 @@ module {
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
// Sort 0 elements.
+ // Quick sort.
// CHECK: [10, 2, 0, 5, 1]
- sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
+ sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [10, 2, 0, 5, 1]
@@ -60,10 +61,15 @@ module {
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+ // Hybrid sort.
+ // CHECK: [10, 2, 0, 5, 1]
+ sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
+ call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Sort the first 4 elements, with the last valid value untouched.
+ // Quick sort.
// CHECK: [0, 2, 5, 10, 1]
- sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
+ sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [0, 2, 5, 10, 1]
@@ -77,6 +83,10 @@ module {
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+ // Hybrid sort.
+ // CHECK: [0, 2, 5, 10, 1]
+ sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
+ call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Prepare more buffers of different dimensions.
%x1s = memref.alloc() : memref<10xi32>
@@ -99,7 +109,7 @@ module {
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort hybrid_quick_sort %i5, %x0, %x1, %x2 jointly %y0
+ sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 46e1020f8d88e..70119f8cead15 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -92,7 +92,7 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo hybrid_quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v : vector<5xi32>