[Mlir-commits] [mlir] 9b800bf - [mlir][sparse] Improve the non-stable sort implementation.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 7 07:38:48 PST 2022
Author: bixia1
Date: 2022-11-07T07:38:42-08:00
New Revision: 9b800bf79d9d2fa18ed5be891346155238015515
URL: https://github.com/llvm/llvm-project/commit/9b800bf79d9d2fa18ed5be891346155238015515
DIFF: https://github.com/llvm/llvm-project/commit/9b800bf79d9d2fa18ed5be891346155238015515.diff
LOG: [mlir][sparse] Improve the non-stable sort implementation.
Replace the quick sort partition method with one that is more similar to the
method used by C++ std quick sort. This improves the runtime for sorting
sk_2005.mtx by more than 10x.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137290
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 929d4a4ddf1f3..0af92a656d848 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -33,8 +33,8 @@ static constexpr uint64_t loIdx = 0;
static constexpr uint64_t hiIdx = 1;
static constexpr uint64_t xStartIdx = 2;
-static constexpr const char kMaySwapFuncNamePrefix[] = "_sparse_may_swap_";
static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
+static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
static constexpr const char kBinarySearchFuncNamePrefix[] =
"_sparse_binary_search_";
@@ -90,11 +90,10 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
return result;
}
-/// Creates a function for swapping the values in index i and j for all the
+/// Creates a code block for swapping the values in index i and j for all the
/// buffers.
//
-// The generate IR corresponds to this C like algorithm:
-// if (i != j) {
+// The generated IR corresponds to this C like algorithm:
// swap(x0[i], x0[j]);
// swap(x1[i], x1[j]);
// ...
@@ -102,36 +101,90 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
// swap(y0[i], y0[j]);
// ...
// swap(yn[i], yn[j]);
-// }
-static void createMaySwapFunc(OpBuilder &builder, ModuleOp unused,
- func::FuncOp func, size_t dim) {
+static void createSwap(OpBuilder &builder, Location loc, ValueRange args) {
+ Value i = args[0];
+ Value j = args[1];
+ for (auto arg : args.drop_front(xStartIdx)) {
+ Value vi = builder.create<memref::LoadOp>(loc, arg, i);
+ Value vj = builder.create<memref::LoadOp>(loc, arg, j);
+ builder.create<memref::StoreOp>(loc, vj, arg, i);
+ builder.create<memref::StoreOp>(loc, vi, arg, j);
+ }
+}
+
+/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
+/// compare each pair is created via `compareBuilder`.
+static void createCompareFuncImplementation(
+ OpBuilder &builder, ModuleOp unused, func::FuncOp func, size_t dim,
+ function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
+ compareBuilder) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
-
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
- Value i = args[0];
- Value j = args[1];
+
+ scf::IfOp topIfOp;
+ for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
+ scf::IfOp ifOp = compareBuilder(builder, loc, args[0], args[1],
+ item.value(), (item.index() == dim - 1));
+ if (item.index() == 0) {
+ topIfOp = ifOp;
+ } else {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointAfter(ifOp);
+ builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+ }
+ }
+
+ builder.setInsertionPointAfter(topIfOp);
+ builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+}
+
+/// Generates an if-statement to compare whether x[i] is equal to x[j].
+static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
+ Value j, Value x, bool isLastDim) {
+ Value f = constantI1(builder, loc, false);
+ Value t = constantI1(builder, loc, true);
+ Value vi = builder.create<memref::LoadOp>(loc, x, i);
+ Value vj = builder.create<memref::LoadOp>(loc, x, j);
+
Value cond =
- builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, i, j);
- scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
+ scf::IfOp ifOp =
+ builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
- // If i!=j swap values in the buffers.
+  // x[i] != x[j]:
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, f);
+
+ // x[i] == x[j]:
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- for (auto arg : args.drop_front(xStartIdx)) {
- Value vi = builder.create<memref::LoadOp>(loc, arg, i);
- Value vj = builder.create<memref::LoadOp>(loc, arg, j);
- builder.create<memref::StoreOp>(loc, vj, arg, i);
- builder.create<memref::StoreOp>(loc, vi, arg, j);
+ if (isLastDim == 1) {
+ // Finish checking all dimensions.
+ builder.create<scf::YieldOp>(loc, t);
}
- builder.setInsertionPointAfter(ifOp);
- builder.create<func::ReturnOp>(loc);
+ return ifOp;
+}
+
+/// Creates a function to compare whether xs[i] is equal to xs[j].
+//
+// The generated IR corresponds to this C like algorithm:
+// if (x0[i] != x0[j])
+// return false;
+// else
+// if (x1[i] != x1[j])
+// return false;
+//     else if (x2[i] != x2[j])
+// and so on ...
+static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
+ func::FuncOp func, size_t dim) {
+ createCompareFuncImplementation(builder, unused, func, dim, createEqCompare);
}
-/// Generates an if-statement to compare x[i] and x[j].
+/// Generates an if-statement to compare whether x[i] is less than x[j].
static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
Value i, Value j, Value x,
bool isLastDim) {
@@ -172,8 +225,7 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
return ifOp;
}
-/// Creates a function to compare the xs values in index i and j for all the
-/// dimensions. The function returns true iff xs[i] < xs[j].
+/// Creates a function to compare whether xs[i] is less than xs[j].
//
// The generated IR corresponds to this C like algorithm:
// if (x0[i] < x0[j])
@@ -187,29 +239,8 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
// and so on ...
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
func::FuncOp func, size_t dim) {
- OpBuilder::InsertionGuard insertionGuard(builder);
-
- Block *entryBlock = func.addEntryBlock();
- builder.setInsertionPointToStart(entryBlock);
- Location loc = func.getLoc();
- ValueRange args = entryBlock->getArguments();
-
- scf::IfOp topIfOp;
- for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
- scf::IfOp ifOp =
- createLessThanCompare(builder, loc, args[0], args[1], item.value(),
- (item.index() == dim - 1));
- if (item.index() == 0) {
- topIfOp = ifOp;
- } else {
- OpBuilder::InsertionGuard insertionGuard(builder);
- builder.setInsertionPointAfter(ifOp);
- builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
- }
- }
-
- builder.setInsertionPointAfter(topIfOp);
- builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+ createCompareFuncImplementation(builder, unused, func, dim,
+ createLessThanCompare);
}
/// Creates a function to use a binary search to find the insertion point for
@@ -285,23 +316,94 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
}
+/// Creates code to advance i in a loop based on xs[p] as follows:
+/// while (xs[i] < xs[p]) i += step (step > 0)
+/// or
+/// while (xs[i] > xs[p]) i += step (step < 0)
+/// The routine returns i as well as a boolean value to indicate whether
+/// xs[i] == xs[p].
+static std::pair<Value, Value>
+createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
+ ValueRange xs, Value i, Value p, size_t dim, int step) {
+ Location loc = func.getLoc();
+ scf::WhileOp whileOp =
+ builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
+
+ Block *before =
+ builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
+ builder.setInsertionPointToEnd(before);
+ SmallVector<Value, 6> compareOperands;
+ if (step > 0) {
+ compareOperands.push_back(before->getArgument(0));
+ compareOperands.push_back(p);
+ } else {
+ assert(step < 0);
+ compareOperands.push_back(p);
+ compareOperands.push_back(before->getArgument(0));
+ }
+ compareOperands.append(xs.begin(), xs.end());
+ MLIRContext *context = module.getContext();
+ Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
+ FlatSymbolRefAttr lessThanFunc =
+ getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
+ dim, compareOperands, createLessThanFunc);
+ Value cond = builder
+ .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+ compareOperands)
+ .getResult(0);
+ builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
+
+ Block *after =
+ builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
+ builder.setInsertionPointToEnd(after);
+ Value cs = constantIndex(builder, loc, step);
+ i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
+ builder.create<scf::YieldOp>(loc, ValueRange{i});
+ i = whileOp.getResult(0);
+
+ builder.setInsertionPointAfter(whileOp);
+ compareOperands[0] = i;
+ compareOperands[1] = p;
+ FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
+ builder, func, {i1Type}, kCompareEqFuncNamePrefix, dim, compareOperands,
+ createEqCompareFunc);
+ Value compareEq =
+ builder
+ .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
+ compareOperands)
+ .getResult(0);
+
+ return std::make_pair(whileOp.getResult(0), compareEq);
+}
+
/// Creates a function to perform quick sort partition on the values in the
/// range of index [lo, hi), assuming lo < hi.
//
// The generated IR corresponds to this C like algorithm:
-// int partition(lo, hi, data) {
-// pivot = data[hi - 1];
-// i = (lo – 1) // RHS of the pivot found so far.
-// for (j = lo; j < hi - 1; j++){
-// if (data[j] < pivot){
-// i++;
-// swap data[i] and data[j]
+// int partition(lo, hi, xs) {
+// p = (lo+hi)/2 // pivot index
+// i = lo
+// j = hi-1
+// while (i < j) do {
+// while (xs[i] < xs[p]) i ++;
+// i_eq = (xs[i] == xs[p]);
+// while (xs[j] > xs[p]) j --;
+// j_eq = (xs[j] == xs[p]);
+// if (i < j) {
+// swap(xs[i], xs[j])
+// if (i == p) {
+// p = j;
+// } else if (j == p) {
+// p = i;
+// }
+// if (i_eq && j_eq) {
+// ++i;
+// --j;
+// }
// }
// }
-// i++
-// swap data[i] and data[hi-1])
-// return i
-// }
+// return p
+// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, size_t dim) {
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -309,60 +411,96 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
- MLIRContext *context = module.getContext();
Location loc = func.getLoc();
ValueRange args = entryBlock->getArguments();
Value lo = args[loIdx];
+ Value hi = args[hiIdx];
+ Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
Value c1 = constantIndex(builder, loc, 1);
- Value i = builder.create<arith::SubIOp>(loc, lo, c1);
- Value him1 = builder.create<arith::SubIOp>(loc, args[hiIdx], c1);
- scf::ForOp forOp =
- builder.create<scf::ForOp>(loc, lo, him1, c1, ValueRange{i});
-
- // Start the for-stmt body.
- builder.setInsertionPointToStart(forOp.getBody());
- Value j = forOp.getInductionVar();
- SmallVector<Value, 6> compareOperands{j, him1};
- ValueRange xs = args.slice(xStartIdx, dim);
- compareOperands.append(xs.begin(), xs.end());
- Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
- FlatSymbolRefAttr lessThanFunc =
- getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
- dim, compareOperands, createLessThanFunc);
- Value cond = builder
- .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
- compareOperands)
- .getResult(0);
- scf::IfOp ifOp =
- builder.create<scf::IfOp>(loc, i.getType(), cond, /*else=*/true);
+ Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
+
+ Value i = lo;
+ Value j = builder.create<arith::SubIOp>(loc, hi, c1);
+ SmallVector<Value, 4> operands{i, j, p};
+ SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType()};
+ scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
+
+ // The before-region of the WhileOp.
+ Block *before =
+ builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
+ builder.setInsertionPointToEnd(before);
+ Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+ before->getArgument(0),
+ before->getArgument(1));
+ builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
- // The if-stmt true branch: i++; swap(data[i], data[j]); yield i.
+ // The after-region of the WhileOp.
+ Block *after =
+ builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
+ builder.setInsertionPointToEnd(after);
+ i = after->getArgument(0);
+ j = after->getArgument(1);
+ p = after->getArgument(2);
+
+ auto [iresult, iCompareEq] = createScanLoop(
+ builder, module, func, args.slice(xStartIdx, dim), i, p, dim, 1);
+ i = iresult;
+ auto [jresult, jCompareEq] = createScanLoop(
+ builder, module, func, args.slice(xStartIdx, dim), j, p, dim, -1);
+ j = jresult;
+
+ // If i < j:
+ cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value i1 =
- builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
- SmallVector<Value, 6> swapOperands{i1, j};
+ SmallVector<Value, 6> swapOperands{i, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
- FlatSymbolRefAttr swapFunc = getMangledSortHelperFunc(
- builder, func, TypeRange(), kMaySwapFuncNamePrefix, dim, swapOperands,
- createMaySwapFunc);
- builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
- builder.create<scf::YieldOp>(loc, i1);
-
- // The if-stmt false branch: yield i.
+ createSwap(builder, loc, swapOperands);
+ // If the pivot is moved, update p with the new pivot.
+ Value icond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
+ scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
+ icond, /*else=*/true);
+ builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{j});
+ builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
+ Value jcond =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
+ scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
+ jcond, /*else=*/true);
+ builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{i});
+ builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{p});
+ builder.setInsertionPointAfter(ifOpJ);
+ builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
+ builder.setInsertionPointAfter(ifOpI);
+ Value compareEqIJ =
+ builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
+ scf::IfOp ifOp2 = builder.create<scf::IfOp>(
+ loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
+ builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+ Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
+ Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
+ builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
+ builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{i, j});
+ builder.setInsertionPointAfter(ifOp2);
+ builder.create<scf::YieldOp>(
+ loc,
+ ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
+
+ // False branch for if i < j:
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<scf::YieldOp>(loc, forOp.getRegionIterArgs().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
- // After the if-stmt, yield the updated i value to end the for-stmt body.
+ // Return for the whileOp.
builder.setInsertionPointAfter(ifOp);
- builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
-
- // After the for-stmt: i++; swap(data[i], data[him1]); return i.
- builder.setInsertionPointAfter(forOp);
- i1 = builder.create<arith::AddIOp>(loc, forOp.getResult(0), c1);
- swapOperands[0] = i1;
- swapOperands[1] = him1;
- builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
- builder.create<func::ReturnOp>(loc, i1);
+ builder.create<scf::YieldOp>(loc, ifOp.getResults());
+
+ // Return for the function.
+ builder.setInsertionPointAfter(whileOp);
+ builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}
/// Creates a function to perform quick sort on the value in the range of
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 114bfd874609f..f5634524f7e66 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -92,28 +92,14 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
// CHECK: return %[[C]]
// CHECK: }
-// CHECK-LABEL: func.func private @_sparse_may_swap_1_i8_f32_index(
-// CHECK-SAME: %[[I:arg0]]: index,
-// CHECK-SAME: %[[J:.*]]: index,
-// CHECK-SAME: %[[X0:.*]]: memref<?xi8>,
-// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
-// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) {
-// CHECK: %[[C:.*]] = arith.cmpi ne, %[[I]], %[[J]]
-// CHECK: scf.if %[[C]] {
-// CHECK: %[[Vx0i:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
-// CHECK: %[[Vx0j:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
-// CHECK: memref.store %[[Vx0j]], %[[X0]]{{\[}}%[[I]]]
-// CHECK: memref.store %[[Vx0i]], %[[X0]]{{\[}}%[[J]]]
-// CHECK: %[[Vy0i:.*]] = memref.load %[[Y0]]{{\[}}%[[I]]]
-// CHECK: %[[Vy0j:.*]] = memref.load %[[Y0]]{{\[}}%[[J]]]
-// CHECK: memref.store %[[Vy0j]], %[[Y0]]{{\[}}%[[I]]]
-// CHECK: memref.store %[[Vy0i]], %[[Y0]]{{\[}}%[[J]]]
-// CHECK: %[[Vy1i:.*]] = memref.load %[[Y1]]{{\[}}%[[I]]]
-// CHECK: %[[Vy1j:.*]] = memref.load %[[Y1]]{{\[}}%[[J]]]
-// CHECK: memref.store %[[Vy1j]], %[[Y1]]{{\[}}%[[I]]]
-// CHECK: memref.store %[[Vy1i]], %[[Y1]]{{\[}}%[[J]]]
-// CHECK: }
-// CHECK: return
+// CHECK-LABEL: func.func private @_sparse_compare_eq_1_i8(
+// CHECK-SAME: %[[I:arg0]]: index,
+// CHECK-SAME: %[[J:.*]]: index,
+// CHECK-SAME: %[[X0:.*]]: memref<?xi8>) -> i1 {
+// CHECK: %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK: %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK: %[[C:.*]] = arith.cmpi eq, %[[VI]], %[[VJ]]
+// CHECK: return %[[C]]
// CHECK: }
// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index(
@@ -123,22 +109,27 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
// CHECK-SAME: %[[Y0:.*]]: memref<?xf32>,
// CHECK-SAME: %[[Y1:.*]]: memref<?xindex>) -> index {
// CHECK: %[[C1:.*]] = arith.constant 1
-// CHECK: %[[I:.*]] = arith.subi %[[L]], %[[C1]]
-// CHECK: %[[Hm1:.*]] = arith.subi %[[H]], %[[C1]]
-// CHECK: %[[I3:.*]] = scf.for %[[J:.*]] = %[[L]] to %[[Hm1]] step %[[C1]] iter_args(%[[I2:.*]] = %[[I]]) -> (index) {
-// CHECK: %[[COND:.*]] = func.call @_sparse_less_than_1_i8(%[[J]], %[[Hm1]], %[[X0]])
-// CHECK: %[[IF:.*]] = scf.if %[[COND]] -> (index) {
-// CHECK: %[[Ip1:.*]] = arith.addi %[[I2]], %[[C1]]
-// CHECK: func.call @_sparse_may_swap_1_i8_f32_index(%[[Ip1]], %[[J]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: scf.yield %[[Ip1]]
+// CHECK: %[[VAL_6:.*]] = arith.constant -
+// CHECK: %[[SUM:.*]] = arith.addi %[[L]], %[[H]]
+// CHECK: %[[P:.*]] = arith.shrui %[[SUM]], %[[C1]]
+// CHECK: %[[J:.*]] = arith.subi %[[H]], %[[C1]]
+// CHECK: %[[W:.*]]:3 = scf.while (%[[Ib:.*]] = %[[L]], %[[Jb:.*]] = %[[J]], %[[pb:.*]] = %[[P]]) : (index, index, index) -> (index, index, index) {
+// CHECK: %[[Cn:.*]] = arith.cmpi ult, %[[Ib]], %[[Jb]]
+// CHECK: scf.condition(%[[Cn]]) %[[Ib]], %[[Jb]], %[[pb]]
+// CHECK: } do {
+// CHECK: ^bb0(%[[Ia:.*]]: index, %[[Ja:.*]]: index, %[[Pa:.*]]: index):
+// CHECK: %[[I2:.*]] = scf.while
+// CHECK: %[[Ieq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[I2:.*]], %[[Pa]], %[[X0]])
+// CHECK: %[[J2:.*]] = scf.while
+// CHECK: %[[Jeq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[J2:.*]], %[[Pa]], %[[X0]])
+// CHECK: %[[Cn2:.*]] = arith.cmpi ult, %[[I2]], %[[J2]]
+// CHECK: %[[If:.*]]:3 = scf.if %[[Cn2]] -> (index, index, index) {
// CHECK: } else {
-// CHECK: scf.yield %[[I2]]
+// CHECK: scf.yield %[[I2]], %[[J2]], %[[Pa]]
// CHECK: }
-// CHECK: scf.yield %[[IF:.*]]
+// CHECK: scf.yield %[[If:.*]]#0, %[[If]]#1, %[[If]]#2
// CHECK: }
-// CHECK: %[[I3p1:.*]] = arith.addi %[[I3:.*]], %[[C1]] : index
-// CHECK: call @_sparse_may_swap_1_i8_f32_index(%[[I3p1]], %[[Hm1]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK: return %[[I3p1]]
+// CHECK: return %[[W:.*]]#2
// CHECK: }
// CHECK-LABEL: func.func private @_sparse_sort_nonstable_1_i8_f32_index(
@@ -181,7 +172,7 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
// to verify correctness of the generated code.
//
// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
-// CHECK-DAG: func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG: func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
// CHECK-DAG: func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 650c0885fcb66..f0937e238af58 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -82,7 +82,7 @@ module {
// CHECK: ( 1, 1, 2, 5, 10 )
// CHECK: ( 3, 3, 1, 10, 1 )
// CHECK: ( 9, 9, 4, 7, 2 )
- // CHECK: ( 7, 8, 10, 9, 6 )
+ // CHECK: ( 8, 7, 10, 9, 6 )
call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
More information about the Mlir-commits
mailing list