[Mlir-commits] [mlir] 8550aeb - [mlir][sparse] Extend sorting function generator to support operands beyond (lo, hi, xs, ys).
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 31 10:22:33 PST 2023
Author: bixia1
Date: 2023-01-31T10:22:28-08:00
New Revision: 8550aebd57d08d34c11ba438d07e1e0942b97f31
URL: https://github.com/llvm/llvm-project/commit/8550aebd57d08d34c11ba438d07e1e0942b97f31
DIFF: https://github.com/llvm/llvm-project/commit/8550aebd57d08d34c11ba438d07e1e0942b97f31.diff
LOG: [mlir][sparse] Extend sorting function generator to support operands beyond (lo, hi, xs, ys).
This is to prepare for implementing a hybrid quick sort, which switches to heap
sort when the recursion depth exceeds a certain limit.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142731
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 90dadf71e61b6..5d6f4212a47f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -43,8 +43,8 @@ static constexpr const char kSortNonstableFuncNamePrefix[] =
static constexpr const char kSortStableFuncNamePrefix[] =
"_sparse_sort_stable_";
-using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
- uint64_t, uint64_t, bool)>;
+using FuncGeneratorType = function_ref<void(
+ OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
/// Constructs a function name with this format to facilitate quick sort:
/// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
@@ -69,15 +69,21 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
/// parameters `nx` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+//
+// All sorting function generators take (lo, hi, xs, ys) in `operands` as
+// parameters for the sorting functions. Other parameters, such as the recursive
+// call depth, are appended to the end of the parameter list as
+// "trailing parameters".
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
TypeRange resultTypes, StringRef namePrefix,
uint64_t nx, uint64_t ny, bool isCoo,
- ValueRange operands, FuncGeneratorType createFunc) {
+ ValueRange operands, FuncGeneratorType createFunc,
+ uint32_t nTrailingP = 0) {
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
- operands);
+ operands.drop_back(nTrailingP));
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
MLIRContext *context = module.getContext();
@@ -93,7 +99,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
loc, nameOstream.str(),
FunctionType::get(context, operands.getTypes(), resultTypes));
func.setPrivate();
- createFunc(builder, module, func, nx, ny, isCoo);
+ createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
}
return result;
@@ -242,7 +248,10 @@ static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
// and so on ...
static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo) {
+ bool isCoo, uint32_t nTrailingP = 0) {
+ // Compare functions don't use trailing parameters.
+ (void)nTrailingP;
+ assert(nTrailingP == 0);
createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
createEqCompare);
}
@@ -302,7 +311,10 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
// and so on ...
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo) {
+ bool isCoo, uint32_t nTrailingP = 0) {
+ // Compare functions don't use trailing parameters.
+ (void)nTrailingP;
+ assert(nTrailingP == 0);
createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
createLessThanCompare);
}
@@ -322,7 +334,10 @@ static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo) {
+ bool isCoo, uint32_t nTrailingP = 0) {
+ // Binary search doesn't use trailing parameters.
+ (void)nTrailingP;
+ assert(nTrailingP == 0);
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
@@ -330,7 +345,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value p = args[hiIdx];
- SmallVector<Type, 2> types(2, p.getType()); // only two
+ SmallVector<Type, 2> types(2, p.getType()); // Only two types.
scf::WhileOp whileOp = builder.create<scf::WhileOp>(
loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
@@ -363,7 +378,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
- compareOperands, createLessThanFunc);
+ compareOperands, createLessThanFunc, nTrailingP);
Value cond2 = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
@@ -560,7 +575,10 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo) {
+ bool isCoo, uint32_t nTrailingP = 0) {
+ // Quick sort partition doesn't use trailing parameters.
+ (void)nTrailingP;
+ assert(nTrailingP == 0);
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
@@ -675,7 +693,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
// }
static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo) {
+ bool isCoo, uint32_t nTrailingP) {
+ (void)nTrailingP;
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
@@ -728,7 +747,10 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
// }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo) {
+ bool isCoo, uint32_t nTrailingP) {
+ // Stable sort function doesn't use trailing parameters.
+ (void)nTrailingP;
+ assert(nTrailingP == 0);
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
@@ -821,9 +843,10 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
: kSortNonstableFuncNamePrefix);
FuncGeneratorType funcGenerator =
isStable ? createSortStableFunc : createSortNonstableFunc;
+ uint32_t nTrailingP = 0;
FlatSymbolRefAttr func =
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
- ny, isCoo, operands, funcGenerator);
+ ny, isCoo, operands, funcGenerator, nTrailingP);
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
return success();
}
More information about the Mlir-commits
mailing list