[Mlir-commits] [mlir] 9409bbb - [mlir][sparse] Implement insertion sort for the stable sort operator.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Oct 6 09:48:49 PDT 2022


Author: bixia1
Date: 2022-10-06T09:48:39-07:00
New Revision: 9409bbb2e0527336f8a6f163b1a5a791f4556e95

URL: https://github.com/llvm/llvm-project/commit/9409bbb2e0527336f8a6f163b1a5a791f4556e95
DIFF: https://github.com/llvm/llvm-project/commit/9409bbb2e0527336f8a6f163b1a5a791f4556e95.diff

LOG: [mlir][sparse] Implement insertion sort for the stable sort operator.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135182

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 8741f5f7d89b3..758d85cace118 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -28,9 +28,19 @@ using namespace mlir::sparse_tensor;
 // Helper methods for the actual rewriting rules.
 //===---------------------------------------------------------------------===//
 
-constexpr uint64_t loIdx = 0;
-constexpr uint64_t hiIdx = 1;
-constexpr uint64_t xStartIdx = 2;
+static constexpr uint64_t loIdx = 0;     // Helper arg index of lo.
+static constexpr uint64_t hiIdx = 1;     // Helper arg index of hi.
+static constexpr uint64_t xStartIdx = 2; // Helper arg index of 1st buffer.
+
+static constexpr const char kMaySwapFuncNamePrefix[] = "_sparse_may_swap_";
+static constexpr const char kLessThanFuncNamePrefix[] = "_sparse_less_than_";
+static constexpr const char kPartitionFuncNamePrefix[] = "_sparse_partition_";
+static constexpr const char kBinarySearchFuncNamePrefix[] =
+    "_sparse_binary_search_";
+static constexpr const char kSortNonstableFuncNamePrefix[] =
+    "_sparse_sort_nonstable_";
+static constexpr const char kSortStableFuncNamePrefix[] =
+    "_sparse_sort_stable_";
 
 typedef function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>
     FuncGeneratorType;
@@ -201,6 +211,79 @@ static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
   builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
 }
 
+/// Creates a function to use a binary search to find the insertion point for
+/// inserting xs[hi] into the sorted values xs[lo..hi).
+//
+// The generated IR corresponds to this C-like algorithm:
+//   p = hi
+//   while (lo < hi)
+//      mid = (lo + hi) >> 1
+//      if (xs[p] < xs[mid])
+//        hi = mid
+//      else
+//        lo = mid + 1
+//   return lo;
+// Ties take the else branch, so the result follows all values equal to xs[p].
+static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
+                                   func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value p = args[hiIdx]; // Index of the value being inserted.
+  SmallVector<Type, 2> types(2, p.getType()); // Types of (lo, hi).
+  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
+      loc, types, SmallVector<Value, 2>{args[loIdx], args[hiIdx]});
+
+  // The before-region: keep looping while lo < hi.
+  Block *before =
+      builder.createBlock(&whileOp.getBefore(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(before);
+  Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                              before->getArgument(0),
+                                              before->getArgument(1));
+  builder.create<scf::ConditionOp>(loc, cond1, before->getArguments());
+
+  // The after-region: halve the search range [lo, hi).
+  Block *after =
+      builder.createBlock(&whileOp.getAfter(), {}, types, {loc, loc});
+  builder.setInsertionPointToEnd(after);
+  Value lo = after->getArgument(0);
+  Value hi = after->getArgument(1);
+  // Compute mid = (lo + hi) >> 1.
+  Value c1 = constantIndex(builder, loc, 1);
+  Value mid = builder.create<arith::ShRUIOp>(
+      loc, builder.create<arith::AddIOp>(loc, lo, hi), c1);
+  Value midp1 = builder.create<arith::AddIOp>(loc, mid, c1);
+
+  // Compare xs[p] < xs[mid].
+  SmallVector<Value, 6> compareOperands{p, mid};
+  compareOperands.append(args.begin() + xStartIdx,
+                         args.begin() + xStartIdx + dim);
+  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThanFunc =
+      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
+                               dim, compareOperands, createLessThanFunc);
+  Value cond2 = builder
+                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+                                          compareOperands)
+                    .getResult(0);
+
+  // Update lo and hi for the WhileOp as follows:
+  //   if (xs[p] < xs[mid])
+  //     hi = mid;
+  //   else
+  //     lo = mid + 1;
+  Value newLo = builder.create<arith::SelectOp>(loc, cond2, lo, midp1);
+  Value newHi = builder.create<arith::SelectOp>(loc, cond2, mid, hi);
+  builder.create<scf::YieldOp>(loc, ValueRange{newLo, newHi});
+  // Return the final lo (result 0 of the loop) as the insertion point.
+  builder.setInsertionPointAfter(whileOp);
+  builder.create<func::ReturnOp>(loc, whileOp.getResult(0));
+}
+
 /// Creates a function to perform quick sort partition on the values in the
 /// range of index [lo, hi), assuming lo < hi.
 //
@@ -243,7 +326,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   compareOperands.append(xs.begin(), xs.end());
   Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
   FlatSymbolRefAttr lessThanFunc =
-      getMangledSortHelperFunc(builder, func, {i1Type}, "_sparse_less_than_",
+      getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
                                dim, compareOperands, createLessThanFunc);
   Value cond = builder
                    .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
@@ -258,9 +341,9 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
       builder.create<arith::AddIOp>(loc, forOp.getRegionIterArgs().front(), c1);
   SmallVector<Value, 6> swapOperands{i1, j};
   swapOperands.append(args.begin() + xStartIdx, args.end());
-  FlatSymbolRefAttr swapFunc =
-      getMangledSortHelperFunc(builder, func, TypeRange(), "_sparse_may_swap_",
-                               dim, swapOperands, createMaySwapFunc);
+  FlatSymbolRefAttr swapFunc = getMangledSortHelperFunc(
+      builder, func, TypeRange(), kMaySwapFuncNamePrefix, dim, swapOperands,
+      createMaySwapFunc);
   builder.create<func::CallOp>(loc, swapFunc, TypeRange(), swapOperands);
   builder.create<scf::YieldOp>(loc, i1);
 
@@ -292,8 +375,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
 //        quickSort(p + 1, hi, data);
 //   }
 // }
-static void createSortFunc(OpBuilder &builder, ModuleOp module,
-                           func::FuncOp func, size_t dim) {
+static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
+                                    func::FuncOp func, size_t dim) {
   OpBuilder::InsertionGuard insertionGuard(builder);
   Block *entryBlock = func.addEntryBlock();
   builder.setInsertionPointToStart(entryBlock);
@@ -310,8 +393,8 @@ static void createSortFunc(OpBuilder &builder, ModuleOp module,
   // The if-stmt true branch.
   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
   FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
-      builder, func, {IndexType::get(context)}, "_sparse_partition_", dim, args,
-      createPartitionFunc);
+      builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim,
+      args, createPartitionFunc);
   auto p = builder.create<func::CallOp>(
       loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
 
@@ -331,6 +414,78 @@ static void createSortFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc);
 }
 
+/// Creates a function to perform insertion sort on the values in the range of
+/// index [lo, hi), using a binary search to find each insertion point.
+//
+// The generated IR corresponds to this C-like algorithm:
+// void insertionSort(lo, hi, data) {
+//   for (i = lo+1; i < hi; i++) {
+//      d = data[i];
+//      p = binarySearch(lo, i, data)
+//      for (j = 0; j < i - p; j++)
+//        data[i-j] = data[i-j-1]
+//      data[p] = d
+//   }
+// }
+static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
+                                 func::FuncOp func, size_t dim) {
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  MLIRContext *context = module.getContext();
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value c1 = constantIndex(builder, loc, 1);
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value lop1 = builder.create<arith::AddIOp>(loc, lo, c1);
+
+  // Start the outer for-stmt: for (i = lo + 1; i < hi; i++).
+  scf::ForOp forOpI = builder.create<scf::ForOp>(loc, lop1, hi, c1);
+  builder.setInsertionPointToStart(forOpI.getBody());
+  Value i = forOpI.getInductionVar();
+
+  // Binary search xs[lo..i) to find the insertion point p for xs[i].
+  SmallVector<Value, 6> operands{lo, i};
+  operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim);
+  FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
+      builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
+      dim, operands, createBinarySearchFunc);
+  Value p = builder
+                .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
+                                      operands)
+                .getResult(0);
+
+  // Save data[i] of every buffer (all args after lo and hi) into d.
+  ValueRange data = args.drop_front(xStartIdx);
+  SmallVector<Value, 6> d;
+  for (Value v : data)
+    d.push_back(builder.create<memref::LoadOp>(loc, v, i));
+
+  // Start the inner for-stmt with induction variable j, for moving data[p..i)
+  // to data[p+1..i+1).
+  Value imp = builder.create<arith::SubIOp>(loc, i, p);
+  Value c0 = constantIndex(builder, loc, 0);
+  scf::ForOp forOpJ = builder.create<scf::ForOp>(loc, c0, imp, c1);
+  builder.setInsertionPointToStart(forOpJ.getBody());
+  Value j = forOpJ.getInductionVar();
+  Value imj = builder.create<arith::SubIOp>(loc, i, j);
+  Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1);
+  for (Value v : data) {
+    Value t = builder.create<memref::LoadOp>(loc, v, imjm1);
+    builder.create<memref::StoreOp>(loc, t, v, imj);
+  }
+
+  // Store the saved values d into data[p].
+  builder.setInsertionPointAfter(forOpJ);
+  for (auto it : llvm::zip(d, data))
+    builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p);
+
+  builder.setInsertionPointAfter(forOpI);
+  builder.create<func::ReturnOp>(loc);
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse buffer rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -425,9 +580,13 @@ struct SortRewriter : public OpRewritePattern<SortOp> {
     addValues(xs);
     addValues(op.getYs());
     auto insertPoint = op->getParentOfType<func::FuncOp>();
-    FlatSymbolRefAttr func = getMangledSortHelperFunc(
-        rewriter, insertPoint, TypeRange(), "_sparse_sort_", xs.size(),
-        operands, createSortFunc);
+    SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
+                                            : kSortNonstableFuncNamePrefix);
+    FuncGeneratorType funcGenerator =
+        op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+    FlatSymbolRefAttr func =
+        getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
+                                 xs.size(), operands, funcGenerator);
     rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
     return success();
   }

diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index ccf9d40f59de2..dd81c523ce329 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-buffer-rewrite  --canonicalize --cse | FileCheck %s
+// RUN: mlir-opt %s -split-input-file --sparse-buffer-rewrite  --canonicalize --cse | FileCheck %s
 
 // CHECK-LABEL: func @sparse_push_back(
 //  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
@@ -26,6 +26,8 @@ func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2:
   return %0 : memref<?xf64>
 }
 
+// -----
+
 // CHECK-LABEL: func @sparse_push_back_inbound(
 //  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
 //  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
@@ -42,6 +44,8 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
   return %0 : memref<?xf64>
 }
 
+// -----
+
 // CHECK-LABEL:   func.func private @_sparse_less_than_1_i8(
 // CHECK-SAME:                                              %[[I:arg0]]: index,
 // CHECK-SAME:                                              %[[J:.*]]: index,
@@ -101,7 +105,7 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
 // CHECK:           return %[[I3p1]]
 // CHECK:         }
 
-// CHECK-LABEL:   func.func private @_sparse_sort_1_i8_f32_index(
+// CHECK-LABEL:   func.func private @_sparse_sort_nonstable_1_i8_f32_index(
 // CHECK-SAME:                                                   %[[L:arg0]]: index,
 // CHECK-SAME:                                                   %[[H:.*]]: index,
 // CHECK-SAME:                                                   %[[X0:.*]]: memref<?xi8>,
@@ -111,9 +115,9 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
 // CHECK:           %[[COND:.*]] = arith.cmpi ult, %[[L]], %[[H]]
 // CHECK:           scf.if %[[COND]] {
 // CHECK:             %[[P:.*]] = func.call @_sparse_partition_1_i8_f32_index(%[[L]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
-// CHECK:             func.call @_sparse_sort_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[L]], %[[P]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK:             %[[P2:.*]] = arith.addi %[[P]], %[[C1]] : index
-// CHECK:             func.call @_sparse_sort_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
+// CHECK:             func.call @_sparse_sort_nonstable_1_i8_f32_index(%[[P2]], %[[H]], %[[X0]], %[[Y0]], %[[Y1]])
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -126,7 +130,7 @@ func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>
 // CHECK:           %[[C0:.*]] = arith.constant 0
 // CHECK:           %[[DX0:.*]] = memref.cast %[[X0]] : memref<10xi8> to memref<?xi8>
 // CHECK:           %[[DY1:.*]] = memref.cast %[[Y1]] : memref<10xindex> to memref<?xindex>
-// CHECK:           call @_sparse_sort_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
+// CHECK:           call @_sparse_sort_nonstable_1_i8_f32_index(%[[C0]], %[[N]], %[[DX0]], %[[Y0]], %[[DY1]])
 // CHECK:           return %[[X0]], %[[Y0]], %[[Y1]]
 // CHECK:         }
 func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
@@ -135,15 +139,31 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
 }
 
+// -----
+
 // Only check the generated supporting function now. We have integration test
 // to verify correctness of the generated code.
 //
 // CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
 // CHECK-DAG:     func.func private @_sparse_may_swap_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-DAG:     func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG:     func.func private @_sparse_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-DAG:     func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
 // CHECK-LABEL:   func.func @sparse_sort_3d
 func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
   sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
   return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
 }
+
+// -----
+
+// Only check the generated supporting functions. We have an integration test
+// to verify correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
+// CHECK-DAG:     func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL:   func.func @sparse_sort_3d_stable
+func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}

diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 4db44ad58a3e9..650c0885fcb66 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -51,12 +51,24 @@ module {
     sparse_tensor.sort %i0, %x0 : memref<?xi32>
     %x0v0 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %x0v0 : vector<5xi32>
+    // Stable sort.
+    // CHECK: ( 10, 2, 0, 5, 1 )
+    sparse_tensor.sort stable %i0, %x0 : memref<?xi32>
+    %x0v0s = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v0s : vector<5xi32>
 
     // Sort the first 4 elements, with the last valid value untouched.
     // CHECK: ( 0, 2, 5, 10, 1 )
     sparse_tensor.sort %i4, %x0 : memref<?xi32>
     %x0v1 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %x0v1 : vector<5xi32>
+    // Stable sort.
+    // CHECK: ( 0, 2, 5, 10, 1 )
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort stable %i4, %x0 : memref<?xi32>
+    %x0v1s = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v1s : vector<5xi32>
 
     // Prepare more buffers of different dimensions.
     %x1s = memref.alloc() : memref<10xi32>
@@ -65,20 +77,20 @@ module {
     %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
     %y0s = memref.alloc() : memref<7xi32>
     %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
-    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+
+    // Sort "parallel arrays".
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-    call @storeValuesTo(%x2, %c2, %c4, %c4, %c7, %c9)
+    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-
-    // Sort "parallel arrays".
-    // CHECK: ( 0, 1, 2, 5, 10 )
-    // CHECK: ( 3, 3, 1, 10, 1 )
-    // CHECK: ( 4, 9, 4, 7, 2 )
-    // CHECK: ( 8, 7, 10, 9, 6 )
     sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
       : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
     %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
@@ -89,6 +101,29 @@ module {
     vector.print %x2v : vector<5xi32>
     %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y0v : vector<5xi32>
+    // Stable sort.
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 8, 7, 10, 9, 6 )
+    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort stable %i5, %x0, %x1, %x2 jointly %y0
+      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+    %x0v2s = vector.transfer_read %x0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x0v2s : vector<5xi32>
+    %x1vs = vector.transfer_read %x1[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x1vs : vector<5xi32>
+    %x2vs = vector.transfer_read %x2[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %x2vs : vector<5xi32>
+    %y0vs = vector.transfer_read %y0[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %y0vs : vector<5xi32>
 
     // Release the buffers.
     memref.dealloc %x0 : memref<?xi32>


        


More information about the Mlir-commits mailing list