[Mlir-commits] [mlir] 3b1c86c - [mlir][sparse] Implement heap sort for sparse_tensor.sort.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 2 15:36:43 PST 2023


Author: bixia1
Date: 2023-02-02T15:36:38-08:00
New Revision: 3b1c86cd0f6d6b2c49998545cd5789fd8c2a4201

URL: https://github.com/llvm/llvm-project/commit/3b1c86cd0f6d6b2c49998545cd5789fd8c2a4201
DIFF: https://github.com/llvm/llvm-project/commit/3b1c86cd0f6d6b2c49998545cd5789fd8c2a4201.diff

LOG: [mlir][sparse] Implement heap sort for sparse_tensor.sort.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D142913

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
    mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 5d6f4212a47f4..90ca39fe650d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -42,6 +42,8 @@ static constexpr const char kSortNonstableFuncNamePrefix[] =
     "_sparse_sort_nonstable_";
 static constexpr const char kSortStableFuncNamePrefix[] =
     "_sparse_sort_stable_";
+static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
+static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
 
 using FuncGeneratorType = function_ref<void(
     OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
@@ -680,6 +682,240 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
   builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
 }
 
+/// Computes (n-2)/2, assuming n has index type.
+static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
+                                      Value n) {
+  Value i2 = constantIndex(builder, loc, 2);
+  Value res = builder.create<arith::SubIOp>(loc, n, i2);
+  Value i1 = constantIndex(builder, loc, 1);
+  return builder.create<arith::ShRUIOp>(loc, res, i1);
+}
+
+/// Creates a function to heapify the subtree with root `start` within the full
+/// binary tree in the range of index [first, first + n).
+//
+// The generated IR corresponds to this C like algorithm:
+// void shiftDown(first, start, n, data) {
+//   if (n >= 2) {
+//     child = start - first
+//     if ((n-2)/2 >= child) {
+//       // Left child exists.
+//       child = child * 2 + 1 // Initialize the bigger child to left child.
+//       childIndex = child + first
+//       if (child+1 < n && data[childIndex] < data[childIndex+1])
+//         // Right child exists and is bigger.
+//         childIndex++; child++;
+//       // Shift data[start] down to where it belongs in the subtree.
+//       while (data[start] < data[childIndex]) {
+//         swap(data[start], data[childIndex])
+//         start = childIndex
+//         if ((n - 2)/2 >= child) {
+//           // Left child exists.
+//           child = 2*child + 1
+//           childIndex = child + first
+//           if (child+1 < n && data[childIndex] < data[childIndex+1])
+//             childIndex++; child++;
+//         }
+//       }
+//     }
+//   }
+// }
+//
+static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
+                                func::FuncOp func, uint64_t nx, uint64_t ny,
+                                bool isCoo, uint32_t nTrailingP) {
+  // The value n is passed in as a trailing parameter.
+  assert(nTrailingP == 1);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  Value n = entryBlock->getArguments().back();
+  ValueRange args = entryBlock->getArguments().drop_back();
+  Value first = args[loIdx];
+  Value start = args[hiIdx];
+
+  // If (n >= 2).
+  Value c2 = constantIndex(builder, loc, 2);
+  Value condN =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
+  scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
+  builder.setInsertionPointToStart(&ifN.getThenRegion().front());
+  Value child = builder.create<arith::SubIOp>(loc, start, first);
+
+  // If ((n-2)/2 >= child).
+  Value t = createSubTwoDividedByTwo(builder, loc, n);
+  Value condNc =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
+  scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
+
+  builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
+  Value c1 = constantIndex(builder, loc, 1);
+  SmallVector<Value> compareOperands{start, start};
+  uint64_t numXBuffers = isCoo ? 1 : nx;
+  compareOperands.append(args.begin() + xStartIdx,
+                         args.begin() + xStartIdx + numXBuffers);
+  Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
+  FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+      builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
+      compareOperands, createLessThanFunc);
+
+  // Generate code to inspect the children of 'r' and return the larger child
+  // as follows:
+  //   child = r * 2 + 1 // Left child.
+  //   childIndex = child + first
+  //   if (child+1 < n && data[childIndex] < data[childIndex+1])
+  //     childIndex ++; child ++ // Right child is bigger.
+  auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
+    Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
+    lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
+    Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
+    Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
+    Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+                                                rChild, n);
+    SmallVector<Type, 2> ifTypes(2, r.getType());
+    scf::IfOp if1 =
+        builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
+    builder.setInsertionPointToStart(&if1.getThenRegion().front());
+    Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
+    // Compare data[left] < data[right].
+    compareOperands[0] = lChildIdx;
+    compareOperands[1] = rChildIdx;
+    Value cond2 = builder
+                      .create<func::CallOp>(loc, lessThanFunc,
+                                            TypeRange{i1Type}, compareOperands)
+                      .getResult(0);
+    scf::IfOp if2 =
+        builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
+    builder.setInsertionPointToStart(&if2.getThenRegion().front());
+    builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
+    builder.setInsertionPointToStart(&if2.getElseRegion().front());
+    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
+    builder.setInsertionPointAfter(if2);
+    builder.create<scf::YieldOp>(loc, if2.getResults());
+    builder.setInsertionPointToStart(&if1.getElseRegion().front());
+    builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
+    builder.setInsertionPointAfter(if1);
+    return std::make_pair(if1.getResult(0), if1.getResult(1));
+  };
+
+  Value childIdx;
+  std::tie(child, childIdx) = getLargerChild(child);
+
+  // While (data[start] < data[childIndex]).
+  SmallVector<Type, 3> types(3, child.getType());
+  scf::WhileOp whileOp = builder.create<scf::WhileOp>(
+      loc, types, SmallVector<Value, 2>{start, child, childIdx});
+
+  // The before-region of the WhileOp.
+  SmallVector<Location, 3> locs(3, loc);
+  Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+  builder.setInsertionPointToEnd(before);
+  start = before->getArgument(0);
+  childIdx = before->getArgument(2);
+  compareOperands[0] = start;
+  compareOperands[1] = childIdx;
+  Value cond = builder
+                   .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+                                         compareOperands)
+                   .getResult(0);
+  builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
+
+  // The after-region of the WhileOp.
+  Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
+  start = after->getArgument(0);
+  child = after->getArgument(1);
+  childIdx = after->getArgument(2);
+  SmallVector<Value> swapOperands{start, childIdx};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  start = childIdx;
+  Value cond2 =
+      builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
+  scf::IfOp if2 = builder.create<scf::IfOp>(
+      loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
+  builder.setInsertionPointToStart(&if2.getThenRegion().front());
+  auto [newChild, newChildIdx] = getLargerChild(child);
+  builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
+  builder.setInsertionPointToStart(&if2.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
+  builder.setInsertionPointAfter(if2);
+  builder.create<scf::YieldOp>(
+      loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
+
+  builder.setInsertionPointAfter(ifN);
+  builder.create<func::ReturnOp>(loc);
+}
+
+/// Creates a function to perform heap sort on the values in the range of index
+/// [lo, hi) with the assumption hi - lo >= 2.
+//
+// The generated IR corresponds to this C like algorithm:
+// void heapSort(lo, hi, data) {
+//   n = hi - lo
+//   for i = (n-2)/2 downto 0
+//     shiftDown(lo, lo+i, n, data)
+//
+//   for l = n downto 2
+//     swap(data[lo], data[lo+l-1])
+//     shiftDown(lo, lo, l-1, data)
+// }
+static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
+                               func::FuncOp func, uint64_t nx, uint64_t ny,
+                               bool isCoo, uint32_t nTrailingP) {
+  // Heap sort function doesn't have trailing parameters.
+  (void)nTrailingP;
+  assert(nTrailingP == 0);
+  OpBuilder::InsertionGuard insertionGuard(builder);
+  Block *entryBlock = func.addEntryBlock();
+  builder.setInsertionPointToStart(entryBlock);
+
+  Location loc = func.getLoc();
+  ValueRange args = entryBlock->getArguments();
+  Value lo = args[loIdx];
+  Value hi = args[hiIdx];
+  Value n = builder.create<arith::SubIOp>(loc, hi, lo);
+
+  // For i = (n-2)/2 downto 0.
+  Value c0 = constantIndex(builder, loc, 0);
+  Value c1 = constantIndex(builder, loc, 1);
+  Value s = createSubTwoDividedByTwo(builder, loc, n);
+  Value up = builder.create<arith::AddIOp>(loc, s, c1);
+  scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
+  builder.setInsertionPointToStart(forI.getBody());
+  Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
+  Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
+  SmallVector<Value> shiftDownOperands = {lo, lopi};
+  shiftDownOperands.append(args.begin() + xStartIdx, args.end());
+  shiftDownOperands.push_back(n);
+  FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
+      builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
+      shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
+  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
+                               shiftDownOperands);
+
+  builder.setInsertionPointAfter(forI);
+  // For l = n downto 2.
+  up = builder.create<arith::SubIOp>(loc, n, c1);
+  scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
+  builder.setInsertionPointToStart(forL.getBody());
+  Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
+  Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
+  loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
+  SmallVector<Value> swapOperands{lo, loplm1};
+  swapOperands.append(args.begin() + xStartIdx, args.end());
+  createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+  shiftDownOperands[1] = lo;
+  shiftDownOperands[shiftDownOperands.size() - 1] =
+      builder.create<arith::SubIOp>(loc, l, c1);
+  builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
+                               shiftDownOperands);
+
+  builder.setInsertionPointAfter(forL);
+  builder.create<func::ReturnOp>(loc);
+}
+
 /// Creates a function to perform quick sort on the value in the range of
 /// index [lo, hi).
 //
@@ -836,14 +1072,27 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
     }
     operands.push_back(v);
   }
-  bool isStable =
-      (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
+
   auto insertPoint = op->template getParentOfType<func::FuncOp>();
-  SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
-                                    : kSortNonstableFuncNamePrefix);
-  FuncGeneratorType funcGenerator =
-      isStable ? createSortStableFunc : createSortNonstableFunc;
+  SmallString<32> funcName;
+  FuncGeneratorType funcGenerator;
   uint32_t nTrailingP = 0;
+  switch (op.getAlgorithm()) {
+  case SparseTensorSortKind::HybridQuickSort:
+  case SparseTensorSortKind::QuickSort:
+    funcName = kSortNonstableFuncNamePrefix;
+    funcGenerator = createSortNonstableFunc;
+    break;
+  case SparseTensorSortKind::InsertionSortStable:
+    funcName = kSortStableFuncNamePrefix;
+    funcGenerator = createSortStableFunc;
+    break;
+  case SparseTensorSortKind::HeapSort:
+    funcName = kHeapSortFuncNamePrefix;
+    funcGenerator = createHeapSortFunc;
+    break;
+  }
+
   FlatSymbolRefAttr func =
       getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
                                ny, isCoo, operands, funcGenerator, nTrailingP);

diff  --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index b9d56f470d934..68c8366eb822e 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -190,6 +190,20 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m
 
 // -----
 
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
+// CHECK-DAG:     func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL:   func.func @sparse_sort_3d_heap
+func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+  sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+  return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
+
+// -----
+
 // Only check the generated supporting functions. We have integration test to
 // verify correctness of the generated code.
 //
@@ -217,3 +231,16 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
 }
 
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG:     func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG:     func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG:     func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL:   func.func @sparse_sort_coo_heap
+func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+  sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+  return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
\ No newline at end of file

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 19585488dad7d..3c2d9cf62e5c8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -56,6 +56,10 @@ module {
     // CHECK: [10,  2,  0,  5,  1]
     sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Heap sort.
+    // CHECK: [10,  2,  0,  5,  1]
+    sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Sort the first 4 elements, with the last valid value untouched.
     // CHECK: [0,  2,  5, 10,  1]
@@ -67,6 +71,12 @@ module {
       : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
     sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
     call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    // Heap sort.
+    // CHECK: [0,  2,  5,  10,  1]
+    call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
 
     // Prepare more buffers of different dimensions.
     %x1s = memref.alloc() : memref<10xi32>
@@ -114,6 +124,25 @@ module {
     call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
     call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
+    // Heap sort.
+    // CHECK: [1,  1,  2,  5,  10]
+    // CHECK: [3,  3,  1,  10,  1
+    // CHECK: [9,  9,  4,  7,  2
+    // CHECK: [7,  8,  10,  9,  6
+    call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
+      : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+    call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+    call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
+    call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
+    call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
 
     // Release the buffers.
     memref.dealloc %x0 : memref<?xi32>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index b0ff0cf19c767..46e1020f8d88e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -132,6 +132,34 @@ module {
     vector.print %y0v2 : vector<5xi32>
     %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
     vector.print %y1v2 : vector<5xi32>
+    // Heap sort.
+    // CHECK: ( 1, 1, 2, 5, 10 )
+    // CHECK: ( 3, 3, 1, 10, 1 )
+    // CHECK: ( 9, 9, 4, 7, 2 )
+    // CHECK: ( 7, 8, 10, 9, 6 )
+    // CHECK: ( 7, 4, 7, 9, 5 )
+    call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7)
+      : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+    call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
+      : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+    sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+      : memref<?xi32> jointly memref<?xi32>
+    %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x0v3 : vector<5xi32>
+    %x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x1v3 : vector<5xi32>
+    %x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %x2v3 : vector<5xi32>
+    %y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+    vector.print %y0v3 : vector<5xi32>
+    %y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
+    vector.print %y1v3 : vector<5xi32>
 
     // Release the buffers.
     memref.dealloc %xy : memref<?xi32>


        


More information about the Mlir-commits mailing list