[Mlir-commits] [mlir] f6424d1 - [mlir][sparse] Improve quick sort by using a loop to sort the bigger partition.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 10 20:43:14 PST 2023
Author: bixia1
Date: 2023-03-10T20:43:08-08:00
New Revision: f6424d11cb3f6d1ece0e3a4633abfd8427d463ff
URL: https://github.com/llvm/llvm-project/commit/f6424d11cb3f6d1ece0e3a4633abfd8427d463ff
DIFF: https://github.com/llvm/llvm-project/commit/f6424d11cb3f6d1ece0e3a4633abfd8427d463ff.diff
LOG: [mlir][sparse] Improve quick sort by using a loop to sort the bigger partition.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D145440
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 107d9ef7569b..b8cf62366d25 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -918,9 +918,13 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
builder.create<func::ReturnOp>(loc);
}
-static void createQuickSort(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, ValueRange args, uint64_t nx,
- uint64_t ny, bool isCoo, uint32_t nTrailingP) {
+/// A helper for generating code to perform quick sort. It partitions [lo, hi),
+/// recursively calls quick sort to process the smaller partition and returns
+/// the bigger partition to be processed by the enclosing while-loop.
+static std::pair<Value, Value>
+createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
+ ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
+ uint32_t nTrailingP) {
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
Value lo = args[loIdx];
@@ -928,20 +932,45 @@ static void createQuickSort(OpBuilder &builder, ModuleOp module,
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
- auto p = builder.create<func::CallOp>(loc, partitionFunc,
- TypeRange{IndexType::get(context)},
- args.drop_back(nTrailingP));
-
- SmallVector<Value> lowOperands{lo, p.getResult(0)};
- lowOperands.append(args.begin() + xStartIdx, args.end());
- builder.create<func::CallOp>(loc, func, lowOperands);
-
- SmallVector<Value> highOperands{
- builder.create<arith::AddIOp>(loc, p.getResult(0),
- constantIndex(builder, loc, 1)),
- hi};
- highOperands.append(args.begin() + xStartIdx, args.end());
- builder.create<func::CallOp>(loc, func, highOperands);
+ Value p = builder
+ .create<func::CallOp>(loc, partitionFunc,
+ TypeRange{IndexType::get(context)},
+ args.drop_back(nTrailingP))
+ .getResult(0);
+ Value pP1 =
+ builder.create<arith::AddIOp>(loc, p, constantIndex(builder, loc, 1));
+ Value lenLow = builder.create<arith::SubIOp>(loc, p, lo);
+ Value lenHigh = builder.create<arith::SubIOp>(loc, hi, p);
+ Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
+ lenLow, lenHigh);
+
+ SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
+
+ Value c0 = constantIndex(builder, loc, 0);
+ auto mayRecursion = [&](Value low, Value high, Value len) {
+ Value cond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, len, c0);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ SmallVector<Value> operands{low, high};
+ operands.append(args.begin() + xStartIdx, args.end());
+ builder.create<func::CallOp>(loc, func, operands);
+ builder.setInsertionPointAfter(ifOp);
+ };
+
+  // Recursively call quickSort to process the smaller partition and return
+  // the bigger partition to be processed by the enclosing while-loop.
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ mayRecursion(lo, p, lenLow);
+ builder.create<scf::YieldOp>(loc, ValueRange{pP1, hi});
+
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ mayRecursion(pP1, hi, lenHigh);
+ builder.create<scf::YieldOp>(loc, ValueRange{lo, p});
+
+ builder.setInsertionPointAfter(ifOp);
+ return std::make_pair(ifOp.getResult(0), ifOp.getResult(1));
}
/// Creates a function to perform insertion sort on the values in the range of
@@ -1036,16 +1065,21 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
//
// When nTrailingP == 0, the generated IR corresponds to this C like algorithm:
// void quickSort(lo, hi, data) {
-// if (lo + 1 < hi) {
+// while (lo + 1 < hi) {
// p = partition(low, high, data);
-// quickSort(lo, p, data);
-// quickSort(p + 1, hi, data);
+// if (len(lo, p) < len(p+1, hi)) {
+// quickSort(lo, p, data);
+// lo = p+1;
+// } else {
+// quickSort(p + 1, hi, data);
+// hi = p;
+// }
// }
// }
//
// When nTrailingP == 1, the generated IR corresponds to this C like algorithm:
// void hybridQuickSort(lo, hi, data, depthLimit) {
-// if (lo + 1 < hi) {
+// while (lo + 1 < hi) {
// len = hi - lo;
// if (len <= limit) {
// insertionSort(lo, hi, data);
@@ -1055,10 +1089,14 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
// heapSort(lo, hi, data);
// } else {
// p = partition(low, high, data);
-// quickSort(lo, p, data);
-// quickSort(p + 1, hi, data);
+// if (len(lo, p) < len(p+1, hi)) {
+// quickSort(lo, p, data, depthLimit);
+// lo = p+1;
+// } else {
+// quickSort(p + 1, hi, data, depthLimit);
+// hi = p;
+// }
// }
-// depthLimit ++;
// }
// }
// }
@@ -1073,70 +1111,98 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(entryBlock);
Location loc = func.getLoc();
- ValueRange args = entryBlock->getArguments();
+ SmallVector<Value> args;
+ args.append(entryBlock->getArguments().begin(),
+ entryBlock->getArguments().end());
Value lo = args[loIdx];
Value hi = args[hiIdx];
- Value loCmp =
+ SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
+ scf::WhileOp whileOp =
+ builder.create<scf::WhileOp>(loc, types, SmallVector<Value, 2>{lo, hi});
+
+ // The before-region of the WhileOp.
+ Block *before =
+ builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
+ builder.setInsertionPointToEnd(before);
+ lo = before->getArgument(0);
+ hi = before->getArgument(1);
+ Value loP1 =
builder.create<arith::AddIOp>(loc, lo, constantIndex(builder, loc, 1));
- Value cond =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loCmp, hi);
- scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+ Value needSort =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, loP1, hi);
+ builder.create<scf::ConditionOp>(loc, needSort, before->getArguments());
- // The if-stmt true branch.
- builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value pDepthLimit;
- Value savedDepthLimit;
- scf::IfOp depthIf;
+ // The after-region of the WhileOp.
+ Block *after =
+ builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
+ builder.setInsertionPointToEnd(after);
+ lo = after->getArgument(0);
+ hi = after->getArgument(1);
+ args[0] = lo;
+ args[1] = hi;
if (isHybrid) {
Value len = builder.create<arith::SubIOp>(loc, hi, lo);
Value lenLimit = constantIndex(builder, loc, 30);
Value lenCond = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ule, len, lenLimit);
- scf::IfOp lenIf = builder.create<scf::IfOp>(loc, lenCond, /*else=*/true);
+ scf::IfOp lenIf =
+ builder.create<scf::IfOp>(loc, types, lenCond, /*else=*/true);
// When len <= limit.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
- args.drop_back(nTrailingP), createSortStableFunc);
+ ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
- ValueRange(args.drop_back(nTrailingP)));
+ ValueRange(args).drop_back(nTrailingP));
+ builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
// When len > limit.
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
- pDepthLimit = args.back();
- savedDepthLimit = builder.create<memref::LoadOp>(loc, pDepthLimit);
- Value depthLimit = builder.create<arith::SubIOp>(
- loc, savedDepthLimit, constantI64(builder, loc, 1));
- builder.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+ Value depthLimit = args.back();
+ depthLimit = builder.create<arith::SubIOp>(loc, depthLimit,
+ constantI64(builder, loc, 1));
Value depthCond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
depthLimit, constantI64(builder, loc, 0));
- depthIf = builder.create<scf::IfOp>(loc, depthCond, /*else=*/true);
+ scf::IfOp depthIf =
+ builder.create<scf::IfOp>(loc, types, depthCond, /*else=*/true);
// When depth exceeds limit.
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
- args.drop_back(nTrailingP), createHeapSortFunc);
+ ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
- ValueRange(args.drop_back(nTrailingP)));
+ ValueRange(args).drop_back(nTrailingP));
+ builder.create<scf::YieldOp>(loc, ValueRange{lo, lo});
// When depth doesn't exceed limit.
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
- }
-
- createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+ args.back() = depthLimit;
+ std::tie(lo, hi) =
+ createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+ builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
- if (isHybrid) {
- // Restore depthLimit.
builder.setInsertionPointAfter(depthIf);
- builder.create<memref::StoreOp>(loc, savedDepthLimit, pDepthLimit);
+ lo = depthIf.getResult(0);
+ hi = depthIf.getResult(1);
+ builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
+
+ builder.setInsertionPointAfter(lenIf);
+ lo = lenIf.getResult(0);
+ hi = lenIf.getResult(1);
+ } else {
+ std::tie(lo, hi) =
+ createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
}
- // After the if-stmt.
- builder.setInsertionPointAfter(ifOp);
+ // New [lo, hi) for the next while-loop iteration.
+ builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
+
+ // After the while-loop.
+ builder.setInsertionPointAfter(whileOp);
builder.create<func::ReturnOp>(loc);
}
@@ -1171,9 +1237,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
funcName = kHybridQuickSortFuncNamePrefix;
funcGenerator = createQuickSortFunc;
nTrailingP = 1;
- Value pDepthLimit = rewriter.create<memref::AllocaOp>(
- loc, MemRefType::get({}, rewriter.getI64Type()));
- operands.push_back(pDepthLimit);
// As a heuristics, set depthLimit = 2 * log2(n).
Value lo = operands[loIdx];
Value hi = operands[hiIdx];
@@ -1183,9 +1246,7 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
Value depthLimit = rewriter.create<arith::SubIOp>(
loc, constantI64(rewriter, loc, 64),
rewriter.create<math::CountLeadingZerosOp>(loc, len));
- depthLimit = rewriter.create<arith::ShLIOp>(loc, depthLimit,
- constantI64(rewriter, loc, 1));
- rewriter.create<memref::StoreOp>(loc, depthLimit, pDepthLimit);
+ operands.push_back(depthLimit);
break;
}
case SparseTensorSortKind::QuickSort:
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index dbe0c972e661..68e5c9b96b94 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -132,13 +132,24 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[Lb:.*]] = arith.addi %[[L]], %[[C1]]
-// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H]]
-// CHECK: scf.if %[[COND]] {
-// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: scf.while (%[[L2:.*]] = %[[L]], %[[H2:.*]] = %[[H]])
+// CHECK: %[[Lb:.*]] = arith.addi %[[L2]], %[[C1]]
+// CHECK: %[[COND:.*]] = arith.cmpi ult, %[[Lb]], %[[H2]]
+// CHECK: scf.condition(%[[COND]]) %[[L2]], %[[H2]]
+// CHECK: } do {
+// CHECK: ^bb0(%[[L3:.*]]: index, %[[H3:.*]]: index)
+// CHECK: %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L3]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: %[[PP1:.*]] = arith.addi %[[P]], %[[C1]] : index
+// CHECK: %[[LenL:.*]] = arith.subi %[[P]], %[[L3]]
+// CHECK: %[[LenH:.*]] = arith.subi %[[H3]], %[[P]]
+// CHECK: %[[Cmp:.*]] = arith.cmpi ule, %[[LenL]], %[[LenH]]
+// CHECK: %[[L4:.*]] = arith.select %[[Cmp]], %[[PP1]], %[[L3]]
+// CHECK: %[[H4:.*]] = arith.select %[[Cmp]], %[[H3]], %[[P]]
+// CHECK: scf.if %[[Cmp]]
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[L3]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: else
+// CHECK: func.call @_sparse_qsort_1_i8_f32_index(%[[PP1]], %[[H3]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK: scf.yield %[[L4]], %[[H4]]
// CHECK: }
// CHECK: return
// CHECK: }
@@ -187,7 +198,7 @@ func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: me
// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: memref<i64>) {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
@@ -249,7 +260,7 @@ func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2:
// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: memref<i64>) {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
More information about the Mlir-commits
mailing list