[Mlir-commits] [mlir] 8550aeb - [mlir][sparse] Extend sorting function generator to support operand beyond (lo, hi, xs, ys).

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 31 10:22:33 PST 2023


Author: bixia1
Date: 2023-01-31T10:22:28-08:00
New Revision: 8550aebd57d08d34c11ba438d07e1e0942b97f31

URL: https://github.com/llvm/llvm-project/commit/8550aebd57d08d34c11ba438d07e1e0942b97f31
DIFF: https://github.com/llvm/llvm-project/commit/8550aebd57d08d34c11ba438d07e1e0942b97f31.diff

LOG: [mlir][sparse] Extend sorting function generator to support operand beyond (lo, hi, xs, ys).

This prepares for implementing a hybrid quick sort, which switches to heap
sort when the recursion depth exceeds a certain limit.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142731

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 90dadf71e61b6..5d6f4212a47f4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -43,8 +43,8 @@ static constexpr const char kSortNonstableFuncNamePrefix[] =
 static constexpr const char kSortStableFuncNamePrefix[] =
     "_sparse_sort_stable_";
 
-using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
-                                            uint64_t, uint64_t, bool)>;
+using FuncGeneratorType = function_ref<void(
+    OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
 
 /// Constructs a function name with this format to facilitate quick sort:
 ///   <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
@@ -69,15 +69,21 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
 /// parameters `nx` and `ny` tell the number of x and y values provided
 /// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
 /// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+//
+// All sorting function generators take (lo, hi, xs, ys) in `operands` as
+// parameters for the sorting functions. Other parameters, such as the recursive
+// call depth, are appended to the end of the parameter list as
+// "trailing parameters".
 static FlatSymbolRefAttr
 getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
                          TypeRange resultTypes, StringRef namePrefix,
                          uint64_t nx, uint64_t ny, bool isCoo,
-                         ValueRange operands, FuncGeneratorType createFunc) {
+                         ValueRange operands, FuncGeneratorType createFunc,
+                         uint32_t nTrailingP = 0) {
   SmallString<32> nameBuffer;
   llvm::raw_svector_ostream nameOstream(nameBuffer);
   getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
-                               operands);
+                               operands.drop_back(nTrailingP));
 
   ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
   MLIRContext *context = module.getContext();
@@ -93,7 +99,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
         loc, nameOstream.str(),
         FunctionType::get(context, operands.getTypes(), resultTypes));
     func.setPrivate();
-    createFunc(builder, module, func, nx, ny, isCoo);
+    createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
   }
 
   return result;
@@ -242,7 +248,10 @@ static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
 //       and so on ...
 static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
                                 func::FuncOp func, uint64_t nx, uint64_t ny,
-                                bool isCoo) {
+                                bool isCoo, uint32_t nTrailingP = 0) {
+  // Compare functions don't use trailing parameters.
+  (void)nTrailingP;
+  assert(nTrailingP == 0);
   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                   createEqCompare);
 }
@@ -302,7 +311,10 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
 //       and so on ...
 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
                                func::FuncOp func, uint64_t nx, uint64_t ny,
-                               bool isCoo) {
+                               bool isCoo, uint32_t nTrailingP = 0) {
+  // Compare functions don't use trailing parameters.
+  (void)nTrailingP;
+  assert(nTrailingP == 0);
   createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
                                   createLessThanCompare);
 }
@@ -322,7 +334,10 @@ static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
 //
 static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
                                    func::FuncOp func, uint64_t nx, uint64_t ny,
-                                   bool isCoo) {
+                                   bool isCoo, uint32_t nTrailingP = 0) {
+  // Binary search doesn't use trailing parameters.
+  (void)nTrailingP;
+  assert(nTrailingP == 0);
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -330,7 +345,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
   Location loc = func.getLoc();
   ValueRange args = entryBlock->getArguments();
   Value p = args[hiIdx];
-  SmallVector<Type, 2> types(2, p.getType()); // only two
+  SmallVector<Type, 2> types(2, p.getType()); // Only two types.
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(
       loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
 
@@ -363,7 +378,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
   Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
   FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
       builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
-      compareOperands, createLessThanFunc);
+      compareOperands, createLessThanFunc, nTrailingP);
   Value cond2 = builder
                     .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
                                           compareOperands)
@@ -560,7 +575,10 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
 //   }
 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, uint64_t nx, uint64_t ny,
-                                bool isCoo) {
+                                bool isCoo, uint32_t nTrailingP = 0) {
+  // Quick sort partition doesn't use trailing parameters.
+  (void)nTrailingP;
+  assert(nTrailingP == 0);
   OpBuilder::InsertionGuard insertionGuard(builder);
 
   Block *entryBlock = func.addEntryBlock();
@@ -675,7 +693,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
 // }
 static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
                                     func::FuncOp func, uint64_t nx, uint64_t ny,
-                                    bool isCoo) {
+                                    bool isCoo, uint32_t nTrailingP) {
+  (void)nTrailingP;
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -728,7 +747,10 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
 // }
 static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
                                  func::FuncOp func, uint64_t nx, uint64_t ny,
-                                 bool isCoo) {
+                                 bool isCoo, uint32_t nTrailingP) {
+  // Stable sort function doesn't use trailing parameters.
+  (void)nTrailingP;
+  assert(nTrailingP == 0);
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -821,9 +843,10 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
                                     : kSortNonstableFuncNamePrefix);
   FuncGeneratorType funcGenerator =
       isStable ? createSortStableFunc : createSortNonstableFunc;
+  uint32_t nTrailingP = 0;
   FlatSymbolRefAttr func =
       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
-                               ny, isCoo, operands, funcGenerator);
+                               ny, isCoo, operands, funcGenerator, nTrailingP);
   rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
   return success();
 }


        


More information about the Mlir-commits mailing list