[Mlir-commits] [mlir] 7cec4d1 - [mlir][sparse] Change the quick sort pivot selection.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 29 17:57:02 PST 2023


Author: bixia1
Date: 2023-01-29T17:56:57-08:00
New Revision: 7cec4d169d2c5261b484bb7dab276cfd7a4090db

URL: https://github.com/llvm/llvm-project/commit/7cec4d169d2c5261b484bb7dab276cfd7a4090db
DIFF: https://github.com/llvm/llvm-project/commit/7cec4d169d2c5261b484bb7dab276cfd7a4090db.diff

LOG: [mlir][sparse] Change the quick sort pivot selection.

Previously, we chose the value at (lo + hi)/2 as a pivot for partitioning the
data in [lo, hi). We now choose the median of the three values at lo, (lo +
hi)/2, and (hi-1) as a pivot to match the std::qsort implementation.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142679

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 73b5bd48b3f4b..3fc760ee756cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -443,6 +443,93 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
   return std::make_pair(whileOp.getResult(0), compareEq);
 }
 
+/// Creates a code block to swap the values so that data[mi] is the median among
+/// data[lo], data[hi], and data[mi].
+//  The generated code corresponds to this C-like algorithm:
+//  median = mi
+//  if (data[mi] < data[lo])                                (if1)
+//    if (data[hi] < data[lo])                              (if2)
+//       median = data[hi] < data[mi] ? mi : hi
+//    else
+//       median = lo
+//  else
+//    if (data[hi] < data[mi])                              (if3)
+//      median = data[hi] < data[lo] ? lo : hi
+//  if (median != mi) swap data[median] with data[mi]
+static void createChoosePivot(OpBuilder &builder, ModuleOp module,
+                              func::FuncOp func, uint64_t nx, uint64_t ny,
+                              bool isCoo, Value lo, Value hi, Value mi,
+                              ValueRange args) {
+  SmallVector<Value> compareOperands{mi, lo};
+  uint64_t numXBuffers = isCoo ? 1 : nx;
+  compareOperands.append(args.begin() + xStartIdx,
+                         args.begin() + xStartIdx + numXBuffers);
+  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
+  SmallVector<Type, 1> cmpTypes{i1Type};
+  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+      builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
+      compareOperands, createLessThanFunc);
+  Location loc = func.getLoc();
+  // Compare data[mi] < data[lo].
+  Value cond1 =
+      builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
+          .getResult(0);
+  SmallVector<Type, 1> ifTypes{lo.getType()};
+  scf::IfOp ifOp1 =
+      builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
+
+  // Generate an if-stmt to find the index of the median value, assuming we
+  // already know data[b] <= data[a] and haven't compared data[c] yet.
+  auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
+    compareOperands[0] = c;
+    compareOperands[1] = a;
+    // Compare data[c] < data[a].
+    Value cond2 =
+        builder
+            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
+            .getResult(0);
+    scf::IfOp ifOp2 =
+        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
+    builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+    compareOperands[0] = c;
+    compareOperands[1] = b;
+    // Compare data[c] < data[b].
+    Value cond3 =
+        builder
+            .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
+            .getResult(0);
+    builder.create<scf::YieldOp>(
+        loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
+    builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+    builder.create<scf::YieldOp>(loc, ValueRange{a});
+    return ifOp2;
+  };
+
+  builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
+  scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
+  builder.setInsertionPointAfter(ifOp2);
+  builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});
+
+  builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
+  scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);
+
+  builder.setInsertionPointAfter(ifOp3);
+  builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});
+
+  builder.setInsertionPointAfter(ifOp1);
+  Value median = ifOp1.getResult(0);
+  Value cond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);
+
+  SmallVector<Value> swapOperands{median, mi};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  builder.setInsertionPointAfter(ifOp);
+}
+
 /// Creates a function to perform quick sort partition on the values in the
 /// range of index [lo, hi), assuming lo < hi.
 //
@@ -489,7 +576,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
 
   Value i = lo;
   Value j = builder.create<arith::SubIOp>(loc, hi, c1);
-  SmallVector<Value, 3> operands{i, j, p}; // exactly three
+  createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
+  SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
   SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
   scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index c9ee528735cf5..3dc9d5361a22a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -80,7 +80,7 @@ module {
     // CHECK: [1,  1,  2,  5,  10]
     // CHECK: [3,  3,  1,  10,  1
     // CHECK: [9,  9,  4,  7,  2
-    // CHECK: [8,  7,  10,  9,  6
+    // CHECK: [7,  8,  10,  9,  6
     call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 77151782cbe77..7c27a076837d8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -80,8 +80,8 @@ module {
     // CHECK: ( 1, 1, 2, 5, 10 )
     // CHECK: ( 3, 3, 1, 10, 1 )
     // CHECK: ( 9, 9, 4, 7, 2 )
-    // CHECK: ( 8, 7, 10, 9, 6 )
-    // CHECK: ( 4, 7, 7, 9, 5 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
     call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
       : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)


        


More information about the Mlir-commits mailing list