[Mlir-commits] [mlir] 3b1c86c - [mlir][sparse] Implement heap sort for sparse_tensor.sort.
llvmlistbot at llvm.org
Thu Feb 2 15:36:43 PST 2023
Author: bixia1
Date: 2023-02-02T15:36:38-08:00
New Revision: 3b1c86cd0f6d6b2c49998545cd5789fd8c2a4201
URL: https://github.com/llvm/llvm-project/commit/3b1c86cd0f6d6b2c49998545cd5789fd8c2a4201
DIFF: https://github.com/llvm/llvm-project/commit/3b1c86cd0f6d6b2c49998545cd5789fd8c2a4201.diff
LOG: [mlir][sparse] Implement heap sort for sparse_tensor.sort.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142913
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 5d6f4212a47f4..90ca39fe650d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -42,6 +42,8 @@ static constexpr const char kSortNonstableFuncNamePrefix[] =
"_sparse_sort_nonstable_";
static constexpr const char kSortStableFuncNamePrefix[] =
"_sparse_sort_stable_";
+static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
+static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
using FuncGeneratorType = function_ref<void(
OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
@@ -680,6 +682,240 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
builder.create<func::ReturnOp>(loc, whileOp.getResult(2));
}
+/// Computes (n-2)/2, assuming n has index type.
+static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
+ Value n) {
+ Value i2 = constantIndex(builder, loc, 2);
+ Value res = builder.create<arith::SubIOp>(loc, n, i2);
+ Value i1 = constantIndex(builder, loc, 1);
+ return builder.create<arith::ShRUIOp>(loc, res, i1);
+}
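
For intuition: in the array layout of a binary heap, node i has children 2i+1 and 2i+2, so node i has a left child exactly when 2i+1 <= n-1, i.e. when i <= (n-2)/2. For example, with n = 5 the last parent is (5-2)/2 = 1, whose children sit at indices 3 and 4.
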
+
+/// Creates a function to heapify the subtree with root `start` within the full
+/// binary tree in the index range [first, first + n).
+//
+// The generated IR corresponds to this C-like algorithm:
+// void shiftDown(first, start, n, data) {
+// if (n >= 2) {
+// child = start - first
+// if ((n-2)/2 >= child) {
+// // Left child exists.
+// child = child * 2 + 1 // Initialize the bigger child to left child.
+// childIndex = child + first
+// if (child+1 < n && data[childIndex] < data[childIndex+1])
+//       // Right child exists and is bigger.
+// childIndex++; child++;
+// // Shift data[start] down to where it belongs in the subtree.
+//     while (data[start] < data[childIndex]) {
+// swap(data[start], data[childIndex])
+// start = childIndex
+// if ((n - 2)/2 >= child) {
+// // Left child exists.
+// child = 2*child + 1
+//         childIndex = child + first
+//         if (child+1 < n && data[childIndex] < data[childIndex+1])
+// childIndex++; child++;
+// }
+// }
+// }
+// }
+// }
+//
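
For reference, the same sift-down as a small standalone C++ routine; this is an expository sketch with illustrative names, not part of the patch, which instead emits the equivalent IR through scf/arith ops:

    #include <cstddef>
    #include <utility>

    // Heapify the subtree rooted at `start` within the heap occupying
    // data[first, first + n).
    static void shiftDown(int *data, std::size_t first, std::size_t start,
                          std::size_t n) {
      if (n < 2)
        return;
      std::size_t child = start - first; // Heap-local index of `start`.
      if ((n - 2) / 2 < child)
        return;                          // `start` has no children.
      child = 2 * child + 1;             // Left child, assumed bigger for now.
      std::size_t childIndex = child + first;
      if (child + 1 < n && data[childIndex] < data[childIndex + 1]) {
        ++childIndex;                    // Right child exists and is bigger.
        ++child;
      }
      // Shift data[start] down to where it belongs in the subtree.
      while (data[start] < data[childIndex]) {
        std::swap(data[start], data[childIndex]);
        start = childIndex;
        if ((n - 2) / 2 < child)
          break;                         // The new position has no children.
        child = 2 * child + 1;
        childIndex = child + first;
        if (child + 1 < n && data[childIndex] < data[childIndex + 1]) {
          ++childIndex;
          ++child;
        }
      }
    }
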
+static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo, uint32_t nTrailingP) {
+ // The value n is passed in as a trailing parameter.
+ assert(nTrailingP == 1);
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+
+ Location loc = func.getLoc();
+ Value n = entryBlock->getArguments().back();
+ ValueRange args = entryBlock->getArguments().drop_back();
+ Value first = args[loIdx];
+ Value start = args[hiIdx];
+
+ // If (n >= 2).
+ Value c2 = constantIndex(builder, loc, 2);
+ Value condN =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, n, c2);
+ scf::IfOp ifN = builder.create<scf::IfOp>(loc, condN, /*else=*/false);
+ builder.setInsertionPointToStart(&ifN.getThenRegion().front());
+ Value child = builder.create<arith::SubIOp>(loc, start, first);
+
+ // If ((n-2)/2 >= child).
+ Value t = createSubTwoDividedByTwo(builder, loc, n);
+ Value condNc =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
+ scf::IfOp ifNc = builder.create<scf::IfOp>(loc, condNc, /*else=*/false);
+
+ builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
+ Value c1 = constantIndex(builder, loc, 1);
+ SmallVector<Value> compareOperands{start, start};
+ uint64_t numXBuffers = isCoo ? 1 : nx;
+ compareOperands.append(args.begin() + xStartIdx,
+ args.begin() + xStartIdx + numXBuffers);
+ Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
+ FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+ builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
+ compareOperands, createLessThanFunc);
+
+ // Generate code to inspect the children of 'r' and return the larger child
+ // as follows:
+ // child = r * 2 + 1 // Left child.
+ // childIndex = child + first
+ // if (child+1 < n && data[childIndex] < data[childIndex+1])
+  //     childIndex++; child++  // Right child is bigger.
+ auto getLargerChild = [&](Value r) -> std::pair<Value, Value> {
+ Value lChild = builder.create<arith::ShLIOp>(loc, r, c1);
+ lChild = builder.create<arith::AddIOp>(loc, lChild, c1);
+ Value lChildIdx = builder.create<arith::AddIOp>(loc, lChild, first);
+ Value rChild = builder.create<arith::AddIOp>(loc, lChild, c1);
+ Value cond1 = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
+ rChild, n);
+ SmallVector<Type, 2> ifTypes(2, r.getType());
+ scf::IfOp if1 =
+ builder.create<scf::IfOp>(loc, ifTypes, cond1, /*else=*/true);
+ builder.setInsertionPointToStart(&if1.getThenRegion().front());
+ Value rChildIdx = builder.create<arith::AddIOp>(loc, rChild, first);
+ // Compare data[left] < data[right].
+ compareOperands[0] = lChildIdx;
+ compareOperands[1] = rChildIdx;
+ Value cond2 = builder
+ .create<func::CallOp>(loc, lessThanFunc,
+ TypeRange{i1Type}, compareOperands)
+ .getResult(0);
+ scf::IfOp if2 =
+ builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
+ builder.setInsertionPointToStart(&if2.getThenRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{rChild, rChildIdx});
+ builder.setInsertionPointToStart(&if2.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
+ builder.setInsertionPointAfter(if2);
+ builder.create<scf::YieldOp>(loc, if2.getResults());
+ builder.setInsertionPointToStart(&if1.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{lChild, lChildIdx});
+ builder.setInsertionPointAfter(if1);
+ return std::make_pair(if1.getResult(0), if1.getResult(1));
+ };
+
+ Value childIdx;
+ std::tie(child, childIdx) = getLargerChild(child);
+
+ // While (data[start] < data[childIndex]).
+ SmallVector<Type, 3> types(3, child.getType());
+ scf::WhileOp whileOp = builder.create<scf::WhileOp>(
+ loc, types, SmallVector<Value, 2>{start, child, childIdx});
+
+ // The before-region of the WhileOp.
+ SmallVector<Location, 3> locs(3, loc);
+ Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
+ builder.setInsertionPointToEnd(before);
+ start = before->getArgument(0);
+ childIdx = before->getArgument(2);
+ compareOperands[0] = start;
+ compareOperands[1] = childIdx;
+ Value cond = builder
+ .create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
+ compareOperands)
+ .getResult(0);
+ builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
+
+ // The after-region of the WhileOp.
+ Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
+ start = after->getArgument(0);
+ child = after->getArgument(1);
+ childIdx = after->getArgument(2);
+ SmallVector<Value> swapOperands{start, childIdx};
+ swapOperands.append(args.begin() + xStartIdx, args.end());
+ createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+ start = childIdx;
+ Value cond2 =
+ builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
+ scf::IfOp if2 = builder.create<scf::IfOp>(
+ loc, TypeRange{child.getType(), child.getType()}, cond2, /*else=*/true);
+ builder.setInsertionPointToStart(&if2.getThenRegion().front());
+ auto [newChild, newChildIdx] = getLargerChild(child);
+ builder.create<scf::YieldOp>(loc, ValueRange{newChild, newChildIdx});
+ builder.setInsertionPointToStart(&if2.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, ValueRange{child, childIdx});
+ builder.setInsertionPointAfter(if2);
+ builder.create<scf::YieldOp>(
+ loc, ValueRange{start, if2.getResult(0), if2.getResult(1)});
+
+ builder.setInsertionPointAfter(ifN);
+ builder.create<func::ReturnOp>(loc);
+}
+
+/// Creates a function to perform heap sort on the values in the index range
+/// [lo, hi), assuming hi - lo >= 2.
+//
+// The generated IR corresponds to this C-like algorithm:
+// void heapSort(lo, hi, data) {
+// n = hi - lo
+// for i = (n-2)/2 downto 0
+// shiftDown(lo, lo+i, n)
+//
+// for l = n downto 2
+// swap(lo, lo+l-1)
+//     shiftDown(lo, lo, l-1)
+// }
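
And the driver, again as an expository C++ sketch reusing the shiftDown routine sketched above (the generated code instead calls the mangled _sparse_shift_down_* helper):

    // Sort data[lo, hi) in place, assuming hi - lo >= 2.
    static void heapSort(int *data, std::size_t lo, std::size_t hi) {
      std::size_t n = hi - lo;
      // Heapify: sift down each internal node, from the last parent to the root.
      for (std::size_t i = (n - 2) / 2 + 1; i-- > 0;)
        shiftDown(data, lo, lo + i, n);
      // Repeatedly move the maximum past the end of the shrinking heap, then
      // restore the heap property on what remains.
      for (std::size_t l = n; l >= 2; --l) {
        std::swap(data[lo], data[lo + l - 1]);
        shiftDown(data, lo, lo, l - 1);
      }
    }
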
+static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo, uint32_t nTrailingP) {
+ // Heap sort function doesn't have trailing parameters.
+ (void)nTrailingP;
+ assert(nTrailingP == 0);
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ Block *entryBlock = func.addEntryBlock();
+ builder.setInsertionPointToStart(entryBlock);
+
+ Location loc = func.getLoc();
+ ValueRange args = entryBlock->getArguments();
+ Value lo = args[loIdx];
+ Value hi = args[hiIdx];
+ Value n = builder.create<arith::SubIOp>(loc, hi, lo);
+
+ // For i = (n-2)/2 downto 0.
+ Value c0 = constantIndex(builder, loc, 0);
+ Value c1 = constantIndex(builder, loc, 1);
+ Value s = createSubTwoDividedByTwo(builder, loc, n);
+ Value up = builder.create<arith::AddIOp>(loc, s, c1);
+ scf::ForOp forI = builder.create<scf::ForOp>(loc, c0, up, c1);
+ builder.setInsertionPointToStart(forI.getBody());
+ Value i = builder.create<arith::SubIOp>(loc, s, forI.getInductionVar());
+ Value lopi = builder.create<arith::AddIOp>(loc, lo, i);
+ SmallVector<Value> shiftDownOperands = {lo, lopi};
+ shiftDownOperands.append(args.begin() + xStartIdx, args.end());
+ shiftDownOperands.push_back(n);
+ FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
+ builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
+ shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
+ builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
+ shiftDownOperands);
+
+ builder.setInsertionPointAfter(forI);
+ // For l = n downto 2.
+ up = builder.create<arith::SubIOp>(loc, n, c1);
+ scf::ForOp forL = builder.create<scf::ForOp>(loc, c0, up, c1);
+ builder.setInsertionPointToStart(forL.getBody());
+ Value l = builder.create<arith::SubIOp>(loc, n, forL.getInductionVar());
+ Value loplm1 = builder.create<arith::AddIOp>(loc, lo, l);
+ loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
+ SmallVector<Value> swapOperands{lo, loplm1};
+ swapOperands.append(args.begin() + xStartIdx, args.end());
+ createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+ shiftDownOperands[1] = lo;
+ shiftDownOperands[shiftDownOperands.size() - 1] =
+ builder.create<arith::SubIOp>(loc, l, c1);
+ builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
+ shiftDownOperands);
+
+ builder.setInsertionPointAfter(forL);
+ builder.create<func::ReturnOp>(loc);
+}
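
Note that scf.for only counts upward, so both downward loops above are emitted with a reversed induction variable; schematically, in illustrative C++ rather than the actual IR:

    // "for i = s downto 0" is lowered as:
    for (std::size_t iv = 0; iv < s + 1; ++iv) {
      std::size_t i = s - iv; // i takes the values s, s-1, ..., 0.
      // loop body
    }
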
+
/// Creates a function to perform quick sort on the values in the index range
/// [lo, hi).
//
@@ -836,14 +1072,27 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
}
operands.push_back(v);
}
- bool isStable =
- (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
+
auto insertPoint = op->template getParentOfType<func::FuncOp>();
- SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
- : kSortNonstableFuncNamePrefix);
- FuncGeneratorType funcGenerator =
- isStable ? createSortStableFunc : createSortNonstableFunc;
+ SmallString<32> funcName;
+ FuncGeneratorType funcGenerator;
uint32_t nTrailingP = 0;
+ switch (op.getAlgorithm()) {
+ case SparseTensorSortKind::HybridQuickSort:
+ case SparseTensorSortKind::QuickSort:
+ funcName = kSortNonstableFuncNamePrefix;
+ funcGenerator = createSortNonstableFunc;
+ break;
+ case SparseTensorSortKind::InsertionSortStable:
+ funcName = kSortStableFuncNamePrefix;
+ funcGenerator = createSortStableFunc;
+ break;
+ case SparseTensorSortKind::HeapSort:
+ funcName = kHeapSortFuncNamePrefix;
+ funcGenerator = createHeapSortFunc;
+ break;
+ }
+
FlatSymbolRefAttr func =
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
ny, isCoo, operands, funcGenerator, nTrailingP);
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index b9d56f470d934..68c8366eb822e 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -190,6 +190,20 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m
// -----
+// Only check the generated supporting functions. We have integration tests to
+// verify the correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
+// CHECK-LABEL: func.func @sparse_sort_3d_heap
+func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
+ sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+ return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+}
+
+// -----
+
// Only check the generated supporting functions. We have integration tests to
// verify the correctness of the generated code.
//
@@ -217,3 +231,16 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2:
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
+// -----
+
+// Only check the generated supporting functions. We have integration tests to
+// verify the correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL: func.func @sparse_sort_coo_heap
+func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+ sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
\ No newline at end of file
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 19585488dad7d..3c2d9cf62e5c8 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -56,6 +56,10 @@ module {
// CHECK: [10, 2, 0, 5, 1]
sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+ // Heap sort.
+ // CHECK: [10, 2, 0, 5, 1]
+ sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
+ call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Sort the first 4 elements, with the last valid value untouched.
// CHECK: [0, 2, 5, 10, 1]
@@ -67,6 +71,12 @@ module {
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+ // Heap sort.
+ // CHECK: [0, 2, 5, 10, 1]
+ call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
+ call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Prepare more buffers of different dimensions.
%x1s = memref.alloc() : memref<10xi32>
@@ -114,6 +124,25 @@ module {
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
+ // Heap sort.
+ // CHECK: [1, 1, 2, 5, 10]
+ // CHECK: [3, 3, 1, 10, 1]
+ // CHECK: [9, 9, 4, 7, 2]
+ // CHECK: [7, 8, 10, 9, 6]
+ call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
+ : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
+ call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
+ call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
+ call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
+ call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
// Release the buffers.
memref.dealloc %x0 : memref<?xi32>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index b0ff0cf19c767..46e1020f8d88e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -132,6 +132,34 @@ module {
vector.print %y0v2 : vector<5xi32>
%y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
vector.print %y1v2 : vector<5xi32>
+ // Heap sort.
+ // CHECK: ( 1, 1, 2, 5, 10 )
+ // CHECK: ( 3, 3, 1, 10, 1 )
+ // CHECK: ( 9, 9, 4, 7, 2 )
+ // CHECK: ( 7, 8, 10, 9, 6 )
+ // CHECK: ( 7, 4, 7, 9, 5 )
+ call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ : memref<?xi32> jointly memref<?xi32>
+ %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x0v3 : vector<5xi32>
+ %x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x1v3 : vector<5xi32>
+ %x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x2v3 : vector<5xi32>
+ %y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %y0v3 : vector<5xi32>
+ %y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %y1v3 : vector<5xi32>
// Release the buffers.
memref.dealloc %xy : memref<?xi32>