[Mlir-commits] [mlir] 9b800bf - [mlir][sparse] Improve the non-stable sort implementation.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 7 07:38:48 PST 2022


Author: bixia1
Date: 2022-11-07T07:38:42-08:00
New Revision: 9b800bf79d9d2fa18ed5be891346155238015515

URL: https://github.com/llvm/llvm-project/commit/9b800bf79d9d2fa18ed5be891346155238015515
DIFF: https://github.com/llvm/llvm-project/commit/9b800bf79d9d2fa18ed5be891346155238015515.diff

LOG: [mlir][sparse] Improve the non-stable sort implementation.

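Replace the Lomuto-style quick sort partition (pivot at hi-1, single forward
scan) with a Hoare-style partition (middle pivot, scans from both ends) that is
more similar to the partitioning used by C++ std::sort. This improves the
runtime for sorting sk_2005.mtx by more than 10x.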

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137290

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 929d4a4ddf1f3..0af92a656d848 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -33,8 +33,8 @@ static constexpr uint64_t loIdx = 0;
 static constexpr uint64_t hiIdx = 1;
 static constexpr uint64_t xStartIdx = 2;
 
-static constexpr const char kMaySwapFuncNamePrefix[] = "_sparse_may_swap_";
 static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
+static constexpr const char kCompareEqFuncNamePrefix[] = "_sparse_compare_eq_";
 static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
 static constexpr const char kBinarySearchFuncNamePrefix[] =
     "_sparse_binary_search_";
@@ -90,11 +90,10 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
   return result;
 }
 
-/// Creates a function for swapping the values in index i and j for all the
+/// Emits ops at the insertion point that swap the values at i and j for all the
 /// buffers.
 //
-// The generate IR corresponds to this C like algorithm:
-//   if (i != j) {
+// The generated IR corresponds to this C like algorithm (no i != j check):
 //     swap(x0[i], x0[j]);
 //     swap(x1[i], x1[j]);
 //     ...
@@ -102,36 +101,90 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
 //     swap(y0[i], y0[j]);
 //     ...
 //     swap(yn[i], yn[j]);
-//   }
-static void createMaySwapFunc(OpBuilder &builder, ModuleOp unused,
-                              func::FuncOp func, size_t dim) {
+static void createSwap(OpBuilder &builder, Location loc, ValueRange args) {
+  Value i = args[0];
+  Value j = args[1];
+  for (auto arg : args.drop_front(xStartIdx)) {
+    Value vi = builder.create<memref::LoadOp>(loc, arg, i);
+    Value vj = builder.create<memref::LoadOp>(loc, arg, j);
+    builder.create<memref::StoreOp>(loc, vj, arg, i);
+    builder.create<memref::StoreOp>(loc, vi, arg, j);
+  }
+}
+
+/// Creates a function that compares xs[i] and xs[j] dimension by dimension; the
+/// if-statement for each dimension is created via `compareBuilder`.
+static void createCompareFuncImplementation(
+    OpBuilder &builder, ModuleOp unused, func::FuncOp func, size_t dim,
+    function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
+        compareBuilder) {
   OpBuilder::InsertionGuard insertionGuard(builder);
 
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
-
   Location loc = func.getLoc();
   ValueRange args = entryBlock->getArguments();
-  Value i = args[0];
-  Value j = args[1];
+
+  scf::IfOp topIfOp;
+  for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
+    scf::IfOp ifOp = compareBuilder(builder, loc, args[0], args[1],
+                                    item.value(), (item.index() == dim - 1));
+    if (item.index() == 0) {
+      topIfOp = ifOp;
+    } else {
+      OpBuilder::InsertionGuard insertionGuard(builder);
+      builder.setInsertionPointAfter(ifOp);
+      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
+    }
+  }
+
+  builder.setInsertionPointAfter(topIfOp);
+  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+}
+
+/// Generates an if-statement to compare whether x[i] is equal to x[j].
+static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
+                                 Value j, Value x, bool isLastDim) {
+  Value f = constantI1(builder, loc, false);
+  Value t = constantI1(builder, loc, true);
+  Value vi = builder.create<memref::LoadOp>(loc, x, i);
+  Value vj = builder.create<memref::LoadOp>(loc, x, j);
+
   Value cond =
-      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, i, j);
-  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, vi, vj);
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, f.getType(), cond, /*else=*/true);
 
-  // If i!=j swap values in the buffers.
+  // x[i] != x[j]: yield false.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, f);
+
+  // x[i] == x[j]: for a non-last dim, the caller nests the next check here.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  for (auto arg : args.drop_front(xStartIdx)) {
-    Value vi = builder.create<memref::LoadOp>(loc, arg, i);
-    Value vj = builder.create<memref::LoadOp>(loc, arg, j);
-    builder.create<memref::StoreOp>(loc, vj, arg, i);
-    builder.create<memref::StoreOp>(loc, vi, arg, j);
+  if (isLastDim == 1) {
+    // All dimensions compared equal: yield true.
+    builder.create<scf::YieldOp>(loc, t);
   }
 
-  builder.setInsertionPointAfter(ifOp);
-  builder.create<func::ReturnOp>(loc);
+  return ifOp;
+}
+
+/// Creates a function to compare whether xs[i] is equal to xs[j].
+//
+// The generated IR corresponds to this C like algorithm:
+//   if (x0[i] != x0[j])
+//     return false;
+//   else
+//     if (x1[i] != x1[j])
+//       return false;
+//     else if (x2[i] != x2[j])
+//       and so on ... (return true if all dimensions are equal)
+static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
+                                func::FuncOp func, size_t dim) {
+  createCompareFuncImplementation(builder, unused, func, dim, createEqCompare);
 }
 
-/// Generates an if-statement to compare x[i] and x[j].
+/// Generates an if-statement to compare whether x[i] is less than x[j].
 static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
                                        Value i, Value j, Value x,
                                        bool isLastDim) {
@@ -172,8 +225,7 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
   return ifOp;
 }
 
-/// Creates a function to compare the xs values in index i and j for all the
-/// dimensions. The function returns true iff xs[i] < xs[j].
+/// Creates a function that returns true iff xs[i] < xs[j] over all dimensions.
 //
 // The generate IR corresponds to this C like algorithm:
 //   if (x0[i] < x0[j])
@@ -187,29 +239,8 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
 //       and so on ...
 static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
                                func::FuncOp func, size_t dim) {
-  OpBuilder::InsertionGuard insertionGuard(builder);
-
-  Block *entryBlock = func.addEntryBlock();
-  builder.setInsertionPointToStart(entryBlock);
-  Location loc = func.getLoc();
-  ValueRange args = entryBlock->getArguments();
-
-  scf::IfOp topIfOp;
-  for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
-    scf::IfOp ifOp =
-        createLessThanCompare(builder, loc, args[0], args[1], item.value(),
-                              (item.index() == dim - 1));
-    if (item.index() == 0) {
-      topIfOp = ifOp;
-    } else {
-      OpBuilder::InsertionGuard insertionGuard(builder);
-      builder.setInsertionPointAfter(ifOp);
-      builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
-    }
-  }
-
-  builder.setInsertionPointAfter(topIfOp);
-  builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
+  createCompareFuncImplementation(builder, unused, func, dim,
+                                  createLessThanCompare);
 }
 
 /// Creates a function to use a binary search to find the insertion point for
@@ -285,23 +316,94 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
 }
 
+/// Creates code to advance i in a loop based on xs[p] as follows:
+///   while (xs[i] < xs[p]) i += step (step > 0)
+/// or
+///   while (xs[i] > xs[p]) i += step (step < 0)
+/// The routine returns i as well as a boolean value to indicate whether
+/// xs[i] == xs[p].
+static std::pair<Value, Value>
+createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
+               ValueRange xs, Value i, Value p, size_t dim, int step) {
+  Location loc = func.getLoc();
+  scf::WhileOp whileOp =
+      builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
+
+  Block *before =
+      builder.createBlock(&whileOp.getBefore(), {}, {i.getType()}, {loc});
+  builder.setInsertionPointToEnd(before);
+  SmallVector<Value, 6> compareOperands;
+  if (step > 0) {
+    compareOperands.push_back(before->getArgument(0));
+    compareOperands.push_back(p);
+  } else {
+    assert(step < 0);
+    compareOperands.push_back(p);
+    compareOperands.push_back(before->getArgument(0));
+  }
+  compareOperands.append(xs.begin(), xs.end());
+  MLIRContext *context = module.getContext();
+  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThanFunc =
+      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
+                               dim, compareOperands, createLessThanFunc);
+  Value cond = builder
+                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+                                         compareOperands)
+                   .getResult(0);
+  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
+
+  Block *after =
+      builder.createBlock(&whileOp.getAfter(), {}, {i.getType()}, {loc});
+  builder.setInsertionPointToEnd(after);
+  Value cs = constantIndex(builder, loc, step);
+  i = builder.create<arith::AddIOp>(loc, after->getArgument(0), cs);
+  builder.create<scf::YieldOp>(loc, ValueRange{i});
+  i = whileOp.getResult(0);
+
+  builder.setInsertionPointAfter(whileOp);
+  compareOperands[0] = i;
+  compareOperands[1] = p;
+  FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
+      builder, func, {i1Type}, kCompareEqFuncNamePrefix, dim, compareOperands,
+      createEqCompareFunc);
+  Value compareEq =
+      builder
+          .create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
+                                compareOperands)
+          .getResult(0);
+
+  return std::make_pair(whileOp.getResult(0), compareEq);
+}
+
 /// Creates a function to perform quick sort partition on the values in the
 /// range of index [lo, hi), assuming lo < hi.
 //
 // The generated IR corresponds to this C like algorithm:
-// int partition(lo, hi, data) {
-//   pivot = data[hi - 1];
-//   i = (lo – 1)  // RHS of the pivot found so far.
-//   for (j = lo; j < hi - 1; j++){
-//     if (data[j] < pivot){
-//       i++;
-//       swap data[i] and data[j]
+// int partition(lo, hi, xs) {
+//   p = (lo+hi)/2  // pivot index
+//   i = lo
+//   j = hi-1
+//   while (i < j) {
+//     while (xs[i] < xs[p]) i++;
+//     i_eq = (xs[i] == xs[p]);
+//     while (xs[j] > xs[p]) j--;
+//     j_eq = (xs[j] == xs[p]);
+//     if (i < j) {
+//       swap(xs[i], xs[j]); swap(ys[i], ys[j]);
+//       if (i == p) {
+//         p = j;
+//       } else if (j == p) {
+//         p = i;
+//       }
+//       if (i_eq && j_eq) {
+//         ++i;
+//         --j;
+//       }
 //     }
 //   }
-//   i++
-//   swap data[i] and data[hi-1])
-//   return i
-// }
+//   return p
+// }
 static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
                                 func::FuncOp func, size_t dim) {
   OpBuilder::InsertionGuard insertionGuard(builder);
@@ -309,60 +411,96 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
 
-  MLIRContext *context = module.getContext();
   Location loc = func.getLoc();
   ValueRange args = entryBlock->getArguments();
   Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value sum = builder.create<arith::AddIOp>(loc, lo, hi);
   Value c1 = constantIndex(builder, loc, 1);
-  Value i = builder.create<arith::SubIOp>(loc, lo, c1);
-  Value him1 = builder.create<arith::SubIOp>(loc, args[hiIdx], c1);
-  scf::ForOp forOp =
-      builder.create<scf::ForOp>(loc, lo, him1, c1, ValueRange{i});
-
-  // Start the for-stmt body.
-  builder.setInsertionPointToStart(forOp.getBody());
-  Value j = forOp.getInductionVar();
-  SmallVector<Value, 6> compareOperands{j, him1};
-  ValueRange xs = args.slice(xStartIdx, dim);
-  compareOperands.append(xs.begin(), xs.end());
-  Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
-  FlatSymbolRefAttr lessThanFunc =
-      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
-                               dim, compareOperands, createLessThanFunc);
-  Value cond = builder
-                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
-                                         compareOperands)
-                   .getResult(0);
-  scf::IfOp ifOp =
-      builder.create<scf::IfOp>(loc, i.getType(), cond, /*else=*/true);
+  Value p = builder.create<arith::ShRUIOp>(loc, sum, c1);
+
+  Value i = lo;
+  Value j = builder.create<arith::SubIOp>(loc, hi, c1);
+  SmallVector<Value, 4> operands{i, j, p};
+  SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType()};
+  scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands);
+
+  // The before-region of the WhileOp.
+  Block *before =
+      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc, loc});
+  builder.setInsertionPointToEnd(before);
+  Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                             before->getArgument(0),
+                                             before->getArgument(1));
+  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
 
-  // The if-stmt true branch: i++; swap(data[i], data[j]); yield i.
+  // The after-region of the WhileOp.
+  Block *after =
+      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc, loc});
+  builder.setInsertionPointToEnd(after);
+  i = after->getArgument(0);
+  j = after->getArgument(1);
+  p = after->getArgument(2);
+  // Scan i forward and j backward; each scan also reports xs[.] == xs[p].
+  auto [iresult, iCompareEq] = createScanLoop(
+      builder, module, func, args.slice(xStartIdx, dim), i, p, dim, 1);
+  i = iresult;
+  auto [jresult, jCompareEq] = createScanLoop(
+      builder, module, func, args.slice(xStartIdx, dim), j, p, dim, -1);
+  j = jresult;
+
+  // If i < j: swap entries i and j in all buffers, then update i, j and p.
+  cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, i, j);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true);
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-  Value i1 =
-      builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
-  SmallVector<Value, 6> swapOperands{i1, j};
+  SmallVector<Value, 6> swapOperands{i, j};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  FlatSymbolRefAttr swapFunc = getMangledSortHelperFunc(
-      builder, func, TypeRange(), kMaySwapFuncNamePrefix, dim, swapOperands,
-      createMaySwapFunc);
-  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
-  builder.create<scf::YieldOp>(loc, i1);
-
-  // The if-stmt false branch: yield i.
+  createSwap(builder, loc, swapOperands);
+  // If the swap moved the pivot, update p to the pivot's new index.
+  Value icond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
+  scf::IfOp ifOpI = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
+                                              icond, /*else=*/true);
+  builder.setInsertionPointToStart(&ifOpI.getThenRegion().front());
+  builder.create<scf::YieldOp>(loc, ValueRange{j});
+  builder.setInsertionPointToStart(&ifOpI.getElseRegion().front());
+  Value jcond =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, j, p);
+  scf::IfOp ifOpJ = builder.create<scf::IfOp>(loc, TypeRange{p.getType()},
+                                              jcond, /*else=*/true);
+  builder.setInsertionPointToStart(&ifOpJ.getThenRegion().front());
+  builder.create<scf::YieldOp>(loc, ValueRange{i});
+  builder.setInsertionPointToStart(&ifOpJ.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, ValueRange{p});
+  builder.setInsertionPointAfter(ifOpJ);
+  builder.create<scf::YieldOp>(loc, ifOpJ.getResults());
+  builder.setInsertionPointAfter(ifOpI);
+  Value compareEqIJ =
+      builder.create<arith::AndIOp>(loc, iCompareEq, jCompareEq);
+  scf::IfOp ifOp2 = builder.create<scf::IfOp>(
+      loc, TypeRange{i.getType(), j.getType()}, compareEqIJ, /*else=*/true);
+  builder.setInsertionPointToStart(&ifOp2.getThenRegion().front());
+  Value i2 = builder.create<arith::AddIOp>(loc, i, c1);
+  Value j2 = builder.create<arith::SubIOp>(loc, j, c1);
+  builder.create<scf::YieldOp>(loc, ValueRange{i2, j2});
+  builder.setInsertionPointToStart(&ifOp2.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, ValueRange{i, j});
+  builder.setInsertionPointAfter(ifOp2);
+  builder.create<scf::YieldOp>(
+      loc,
+      ValueRange{ifOp2.getResult(0), ifOp2.getResult(1), ifOpI.getResult(0)});
+
+  // i >= j: yield i, j and p unchanged; the loop condition then fails.
   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-  builder.create<scf::YieldOp>(loc, forOp.getRegionIterArgs().front());
+  builder.create<scf::YieldOp>(loc, ValueRange{i, j, p});
 
-  // After the if-stmt, yield the updated i value to end the for-stmt body.
+  // Yield the (i, j, p) produced by the if-statement to the next iteration.
   builder.setInsertionPointAfter(ifOp);
-  builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
-
-  // After the for-stmt: i++; swap(data[i], data[him1]); return i.
-  builder.setInsertionPointAfter(forOp);
-  i1 = builder.create<arith::AddIOp>(loc, forOp.getResult(0), c1);
-  swapOperands[0] = i1;
-  swapOperands[1] = him1;
-  builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
-  builder.create<func::ReturnOp>(loc, i1);
+  builder.create<scf::YieldOp>(loc, ifOp.getResults());
+
+  // After the loop, return the final pivot index p.
+  builder.setInsertionPointAfter(whileOp);
+  builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
 }
 
 /// Creates a function to perform quick sort on the value in the range of

diff  --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 114bfd874609f..f5634524f7e66 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -92,28 +92,14 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
 // CHECK:           return %[[C]]
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_sparse_may_swap_1_i8_f32_index(
-// CHECK-SAME:                                                       %[[I:arg0]]: index,
-// CHECK-SAME:                                                       %[[J:.*]]: index,
-// CHECK-SAME:                                                       %[[X0:.*]]: memref<?xi8>,
-// CHECK-SAME:                                                       %[[Y0:.*]]: memref<?xf32>,
-// CHECK-SAME:                                                       %[[Y1:.*]]: memref<?xindex>) {
-// CHECK:           %[[C:.*]] = arith.cmpi ne, %[[I]], %[[J]]
-// CHECK:           scf.if %[[C]] {
-// CHECK:             %[[Vx0i:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
-// CHECK:             %[[Vx0j:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
-// CHECK:             memref.store %[[Vx0j]], %[[X0]]{{\[}}%[[I]]]
-// CHECK:             memref.store %[[Vx0i]], %[[X0]]{{\[}}%[[J]]]
-// CHECK:             %[[Vy0i:.*]] = memref.load %[[Y0]]{{\[}}%[[I]]]
-// CHECK:             %[[Vy0j:.*]] = memref.load %[[Y0]]{{\[}}%[[J]]]
-// CHECK:             memref.store %[[Vy0j]], %[[Y0]]{{\[}}%[[I]]]
-// CHECK:             memref.store %[[Vy0i]], %[[Y0]]{{\[}}%[[J]]]
-// CHECK:             %[[Vy1i:.*]] = memref.load %[[Y1]]{{\[}}%[[I]]]
-// CHECK:             %[[Vy1j:.*]] = memref.load %[[Y1]]{{\[}}%[[J]]]
-// CHECK:             memref.store %[[Vy1j]], %[[Y1]]{{\[}}%[[I]]]
-// CHECK:             memref.store %[[Vy1i]], %[[Y1]]{{\[}}%[[J]]]
-// CHECK:           }
-// CHECK:           return
+// CHECK-LABEL:   func.func private @_sparse_compare_eq_1_i8(
+// CHECK-SAME:                                               %[[I:arg0]]: index,
+// CHECK-SAME:                                               %[[J:.*]]: index,
+// CHECK-SAME:                                               %[[X0:.*]]: memref<?xi8>) -> i1 {
+// CHECK:           %[[VI:.*]] = memref.load %[[X0]]{{\[}}%[[I]]]
+// CHECK:           %[[VJ:.*]] = memref.load %[[X0]]{{\[}}%[[J]]]
+// CHECK:           %[[C:.*]] = arith.cmpi eq, %[[VI]], %[[VJ]]
+// CHECK:           return %[[C]]
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @_sparse_partition_1_i8_f32_index(
@@ -123,22 +109,27 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
 // CHECK-SAME:                                                        %[[Y0:.*]]: memref<?xf32>,
 // CHECK-SAME:                                                        %[[Y1:.*]]: memref<?xindex>) -> index {
 // CHECK:           %[[C1:.*]] = arith.constant 1
-// CHECK:           %[[I:.*]] = arith.subi %[[L]], %[[C1]]
-// CHECK:           %[[Hm1:.*]] = arith.subi %[[H]], %[[C1]]
-// CHECK:           %[[I3:.*]] = scf.for %[[J:.*]] = %[[L]] to %[[Hm1]] step %[[C1]] iter_args(%[[I2:.*]] = %[[I]]) -> (index) {
-// CHECK:             %[[COND:.*]] = func.call @_sparse_less_than_1_i8(%[[J]], %[[Hm1]], %[[X0]])
-// CHECK:             %[[IF:.*]] = scf.if %[[COND]] -> (index) {
-// CHECK:               %[[Ip1:.*]] = arith.addi %[[I2]], %[[C1]]
-// CHECK:               func.call @_sparse_may_swap_1_i8_f32_index(%[[Ip1]], %[[J]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:               scf.yield %[[Ip1]]
+// CHECK:           %[[VAL_6:.*]] = arith.constant -
+// CHECK:           %[[SUM:.*]] = arith.addi %[[L]], %[[H]]
+// CHECK:           %[[P:.*]] = arith.shrui %[[SUM]], %[[C1]]
+// CHECK:           %[[J:.*]] = arith.subi %[[H]], %[[C1]]
+// CHECK:           %[[W:.*]]:3 = scf.while (%[[Ib:.*]] = %[[L]], %[[Jb:.*]] = %[[J]], %[[pb:.*]] = %[[P]]) : (index, index, index) -> (index, index, index) {
+// CHECK:             %[[Cn:.*]] = arith.cmpi ult, %[[Ib]], %[[Jb]]
+// CHECK:             scf.condition(%[[Cn]]) %[[Ib]], %[[Jb]], %[[pb]]
+// CHECK:           } do {
+// CHECK:           ^bb0(%[[Ia:.*]]: index, %[[Ja:.*]]: index, %[[Pa:.*]]: index):
+// CHECK:             %[[I2:.*]] = scf.while
+// CHECK:             %[[Ieq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[I2:.*]], %[[Pa]], %[[X0]])
+// CHECK:             %[[J2:.*]] = scf.while
+// CHECK:             %[[Jeq:.*]] = func.call @_sparse_compare_eq_1_i8(%[[J2:.*]], %[[Pa]], %[[X0]])
+// CHECK:             %[[Cn2:.*]] = arith.cmpi ult, %[[I2]], %[[J2]]
+// CHECK:             %[[If:.*]]:3 = scf.if %[[Cn2]] -> (index, index, index) {
 // CHECK:             } else {
-// CHECK:               scf.yield %[[I2]]
+// CHECK:               scf.yield %[[I2]], %[[J2]], %[[Pa]]
 // CHECK:             }
-// CHECK:             scf.yield %[[IF:.*]]
+// CHECK:             scf.yield %[[If:.*]]#0, %[[If]]#1, %[[If]]#2
 // CHECK:           }
-// CHECK:           %[[I3p1:.*]] = arith.addi %[[I3:.*]], %[[C1]] : index
-// CHECK:           call @_sparse_may_swap_1_i8_f32_index(%[[I3p1]], %[[Hm1]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:           return %[[I3p1]]
+// CHECK:           return %[[W:.*]]#2
 // CHECK:         }
 
 // CHECK-LABEL:   func.func private @_sparse_sort_nonstable_1_i8_f32_index(
@@ -181,7 +172,7 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
 // to verify correctness of the generated code.
 //
 // CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
-// CHECK-DAG:     func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG:     func.func private @_sparse_compare_eq_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
 // CHECK-DAG:     func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL:   func.func @sparse_sort_3d

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 650c0885fcb66..f0937e238af58 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -82,7 +82,7 @@ module {
     // CHECK: ( 1, 1, 2, 5, 10 )
     // CHECK: ( 3, 3, 1, 10, 1 )
     // CHECK: ( 9, 9, 4, 7, 2 )
-    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
     call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)


        


More information about the Mlir-commits mailing list