[Mlir-commits] [mlir] 7cec4d1 - [mlir][sparse] Change the quick sort pivot selection.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 29 17:57:02 PST 2023
Author: bixia1
Date: 2023-01-29T17:56:57-08:00
New Revision: 7cec4d169d2c5261b484bb7dab276cfd7a4090db
URL: https://github.com/llvm/llvm-project/commit/7cec4d169d2c5261b484bb7dab276cfd7a4090db
DIFF: https://github.com/llvm/llvm-project/commit/7cec4d169d2c5261b484bb7dab276cfd7a4090db.diff
LOG: [mlir][sparse] Change the quick sort pivot selection.
Previously, we chose the value at (lo + hi)/2 as a pivot for partitioning the
data in [lo, hi). We now choose the median for the three values at lo, (lo +
hi)/2, and (hi-1) as a pivot to match the std::qsort implementation.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142679
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 73b5bd48b3f4b..3fc760ee756cb 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -443,6 +443,93 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
return std::make_pair(whileOp.getResult(0), compareEq);
}
+/// Creates a code block to swap the values so that data[mi] is the median among
+/// data[lo], data[hi], and data[mi].
+// The generated code corresponds to this C-like algorithm:
+// median = mi
+// if (data[mi] < data[lo]) (if1)
+// if (data[hi] < data[lo]) (if2)
+// median = data[hi] < data[mi] ? mi : hi
+// else
+// median = lo
+// else
+// if (data[hi] < data[mi]) (if3)
+// median = data[hi] < data[lo] ? lo : hi
+// if (median != mi) swap data[median] with data[mi]
+static void createChoosePivot(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo, Value lo, Value hi, Value mi,
+ ValueRange args) {
+ SmallVector<Value> compareOperands{mi, lo};
+ uint64_t numXBuffers = isCoo ? 1 : nx;
+ compareOperands.append(args.begin() + xStartIdx,
+ args.begin() + xStartIdx + numXBuffers);
+ Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
+ SmallVector<Type, 1> cmpTypes{i1Type};
+ FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+ builder, func, cmpTypes, kLessThanFuncNamePrefix, nx, ny, isCoo,
+ compareOperands, createLessThanFunc);
+ Location loc = func.getLoc();
+ // Compare data[mi] < data[lo]; the result selects between the two branches of if1.
+ Value cond1 =
+ builder.create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
+ .getResult(0);
+ SmallVector<Type, 1> ifTypes{lo.getType()};
+ scf::IfOp ifOp1 =
+ builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
+
+ // Generate an if-stmt that yields the index of the median value, assuming we
+ // already know that data[b] <= data[a] and we haven't compared data[c] yet.
+ auto createFindMedian = [&](Value a, Value b, Value c) -> scf::IfOp {
+ compareOperands[0] = c;
+ compareOperands[1] = a;
+ // Compare data[c] < data[a].
+ Value cond2 =
+ builder
+ .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
+ .getResult(0);
+ scf::IfOp ifOp2 =
+ builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
+ builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+ compareOperands[0] = c;
+ compareOperands[1] = b;
+ // Compare data[c] < data[b]; if true the median is b, otherwise it is c.
+ Value cond3 =
+ builder
+ .create<func::CallOp>(loc, lessThanFunc, cmpTypes, compareOperands)
+ .getResult(0);
+ builder.create<scf::YieldOp>(
+ loc, ValueRange{builder.create<arith::SelectOp>(loc, cond3, b, c)});
+ builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{a});
+ return ifOp2;
+ };
+
+ builder.setInsertionPointToStart(&ifOp1.getThenRegion().front());
+ scf::IfOp ifOp2 = createFindMedian(lo, mi, hi);
+ builder.setInsertionPointAfter(ifOp2);
+ builder.create<scf::YieldOp>(loc, ValueRange{ifOp2.getResult(0)});
+
+ builder.setInsertionPointToStart(&ifOp1.getElseRegion().front());
+ scf::IfOp ifOp3 = createFindMedian(mi, lo, hi);
+
+ builder.setInsertionPointAfter(ifOp3);
+ builder.create<scf::YieldOp>(loc, ValueRange{ifOp3.getResult(0)});
+
+ builder.setInsertionPointAfter(ifOp1);
+ Value median = ifOp1.getResult(0);
+ Value cond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, mi, median);
+ scf::IfOp ifOp =
+ builder.create<scf::IfOp>(loc, TypeRange(), cond, /*else=*/false);
+
+ SmallVector<Value> swapOperands{median, mi};
+ swapOperands.append(args.begin() + xStartIdx, args.end());
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+ builder.setInsertionPointAfter(ifOp);
+}
+
/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
//
@@ -489,7 +576,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
Value i = lo;
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
- SmallVector<Value, 3> operands{i, j, p}; // exactly three
+ createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
+ SmallVector<Value, 3> operands{i, j, p}; // Exactly three values.
SmallVector<Type, 3> types{i.getType(), j.getType(), p.getType()};
scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index c9ee528735cf5..3dc9d5361a22a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -80,7 +80,7 @@ module {
// CHECK: [1, 1, 2, 5, 10]
// CHECK: [3, 3, 1, 10, 1
// CHECK: [9, 9, 4, 7, 2
- // CHECK: [8, 7, 10, 9, 6
+ // CHECK: [7, 8, 10, 9, 6
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 77151782cbe77..7c27a076837d8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -80,8 +80,8 @@ module {
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 9, 9, 4, 7, 2 )
- // CHECK: ( 8, 7, 10, 9, 6 )
- // CHECK: ( 4, 7, 7, 9, 5 )
+ // CHECK: ( 7, 8, 10, 9, 6 )
+ // CHECK: ( 7, 4, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
More information about the Mlir-commits
mailing list