[Mlir-commits] [mlir] 4f729d5 - [mlir][sparse] Add rewriting rules for sparse_tensor.sort_coo.
llvmlistbot at llvm.org
Mon Nov 14 08:49:00 PST 2022
Author: bixia1
Date: 2022-11-14T08:48:53-08:00
New Revision: 4f729d5a7056bbb59621c1332598db924c2f1fd6
URL: https://github.com/llvm/llvm-project/commit/4f729d5a7056bbb59621c1332598db924c2f1fd6
DIFF: https://github.com/llvm/llvm-project/commit/4f729d5a7056bbb59621c1332598db924c2f1fd6.diff
LOG: [mlir][sparse] Add rewriting rules for sparse_tensor.sort_coo.
Refactor the rewriting of sparse_tensor.sort to support the implementation of
sparse_tensor.sort_coo.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137522
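
As context for the change, the new operator sorts several logical index columns that are stored interleaved in one linear buffer, optionally together with extra buffers that are permuted jointly. A minimal sketch of an invocation, with operand names chosen only for illustration and sizes matching the integration test added below:

  // n entries; nx = 3 index columns and ny = 1 value column interleaved in %xy,
  // plus one extra buffer %y1 that is permuted along with %xy.
  sparse_tensor.sort_coo %n, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
      : memref<?xi32> jointly memref<?xi32>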
Added:
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 52a6aff752792..64facdc0a4113 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -529,10 +529,10 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
`xs` values and some `ys` values are put in the linear buffer `xy`. The
optional index attribute `nx` provides the number of `xs` values in `xy`.
- When `ns` is not explicitly specified, its value is 1. The optional index
+ When `nx` is not explicitly specified, its value is 1. The optional index
attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
- explicitly specified, its value is 0. This instruction supports the TACO
- COO style storage format for better sorting performance.
+ explicitly specified, its value is 0. This instruction supports a more
+ efficient way to store the COO definition in sparse tensor type.
The buffer xy should have a dimension not less than n * (nx + ny) while the
buffers in `ys` should have a dimension not less than `n`. The behavior of
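
As a small worked example of the size requirement stated above (sizes taken from the integration test added in this change, and meant only as an illustration): with n = 5, nx = 3 and ny = 1, the linear buffer must hold at least 5 * (3 + 1) = 20 elements, with entry i occupying slots 4*i through 4*i + 3 as (x0, x1, x2, y0):

  // 20 = n * (nx + ny) slots in the linear xy buffer.
  %xys = memref.alloc() : memref<20xi32>
  %xy = memref.cast %xys : memref<20xi32> to memref<?xi32>
  // The test addresses an individual logical column with a strided subview
  // (%i0, %i5, %i4 are index constants 0, 5 and 4).
  %x0 = memref.subview %xy[%i0][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>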
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index d0564cabad314..c556b0d3afe29 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -43,32 +43,42 @@ static constexpr const char kSortNonstableFuncNamePrefix[] =
static constexpr const char kSortStableFuncNamePrefix[] =
"_sparse_sort_stable_";
-using FuncGeneratorType =
- function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, size_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+ uint64_t, uint64_t, bool)>;
/// Constructs a function name with this format to facilitate quick sort:
-/// <namePrefix><dim>_<x type>_<y0 type>..._<yn type>
+/// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
+/// <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
- StringRef namePrefix, size_t dim,
+ StringRef namePrefix, uint64_t nx,
+ uint64_t ny, bool isCoo,
ValueRange operands) {
nameOstream
- << namePrefix << dim << "_"
+ << namePrefix << nx << "_"
<< operands[xStartIdx].getType().cast<MemRefType>().getElementType();
- for (Value v : operands.drop_front(xStartIdx + dim))
+ if (isCoo)
+ nameOstream << "_coo_" << ny;
+
+ uint64_t yBufferOffset = isCoo ? 1 : nx;
+ for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
nameOstream << "_" << v.getType().cast<MemRefType>().getElementType();
}
/// Looks up a function that is appropriate for the given operands being
-/// sorted, and creates such a function if it doesn't exist yet.
+/// sorted, and creates such a function if it doesn't exist yet. The
+/// parameters `nx` and `ny` tell the number of x and y values provided
+/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
+/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
TypeRange resultTypes, StringRef namePrefix,
- size_t dim, ValueRange operands,
- FuncGeneratorType createFunc) {
+ uint64_t nx, uint64_t ny, bool isCoo,
+ ValueRange operands, FuncGeneratorType createFunc) {
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
- getMangledSortHelperFuncName(nameOstream, namePrefix, dim, operands);
+ getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+ operands);
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
MLIRContext *context = module.getContext();
@@ -84,12 +94,61 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
loc, nameOstream.str(),
FunctionType::get(context, operands.getTypes(), resultTypes));
func.setPrivate();
- createFunc(builder, module, func, dim);
+ createFunc(builder, module, func, nx, ny, isCoo);
}
return result;
}
+/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
+/// The code to process the value pairs is generated by `bodyBuilder`.
+static void forEachIJPairInXs(
+ OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
+ bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+ Value iOffset, jOffset;
+ if (isCoo) {
+ Value cstep = constantIndex(builder, loc, nx + ny);
+ iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+ jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
+ }
+ for (uint64_t k = 0; k < nx; k++) {
+ scf::IfOp ifOp;
+ Value i, j, buffer;
+ if (isCoo) {
+ Value ck = constantIndex(builder, loc, k);
+ i = builder.create<arith::AddIOp>(loc, ck, iOffset);
+ j = builder.create<arith::AddIOp>(loc, ck, jOffset);
+ buffer = args[xStartIdx];
+ } else {
+ i = args[0];
+ j = args[1];
+ buffer = args[xStartIdx + k];
+ }
+ bodyBuilder(k, i, j, buffer);
+ }
+}
+
+/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
+/// The code to process the value pairs is generated by `bodyBuilder`.
+static void forEachIJPairInAllBuffers(
+ OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
+ bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+
+ // Create code for the first (nx + ny) buffers. When isCoo==true, these
+ // logical buffers are all from the xy buffer of the sort_coo operator.
+ forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+
+ uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+
+ // Create code for the remaining buffers.
+ Value i = args[0];
+ Value j = args[1];
+ for (const auto &arg :
+ llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
+ bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+ }
+}
+
/// Creates a code block for swapping the values in index i and j for all the
/// buffers.
//
@@ -101,21 +160,23 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
// swap(y0[i], y0[j]);
// ...
// swap(yn[i], yn[j]);
-static void createSwap(OpBuilder &builder, Location loc, ValueRange args) {
- Value i = args[0];
- Value j = args[1];
- for (auto arg : args.drop_front(xStartIdx)) {
- Value vi = builder.create<memref::LoadOp>(loc, arg, i);
- Value vj = builder.create<memref::LoadOp>(loc, arg, j);
- builder.create<memref::StoreOp>(loc, vj, arg, i);
- builder.create<memref::StoreOp>(loc, vi, arg, j);
- }
+static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
+ uint64_t nx, uint64_t ny, bool isCoo) {
+ auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
+ Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
+ Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
+ builder.create<memref::StoreOp>(loc, vj, buffer, i);
+ builder.create<memref::StoreOp>(loc, vi, buffer, j);
+ };
+
+ forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
}
/// Creates a function to compare all the (xs[i], xs[j]) pairs. The method to
/// compare each pair is create via `compareBuilder`.
static void createCompareFuncImplementation(
- OpBuilder &builder, ModuleOp unused, func::FuncOp func, size_t dim,
+ OpBuilder &builder, ModuleOp unused, func::FuncOp func, uint64_t nx,
+ uint64_t ny, bool isCoo,
function_ref<scf::IfOp(OpBuilder &, Location, Value, Value, Value, bool)>
compareBuilder) {
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -126,17 +187,18 @@ static void createCompareFuncImplementation(
ValueRange args = entryBlock->getArguments();
scf::IfOp topIfOp;
- for (const auto &item : llvm::enumerate(args.slice(xStartIdx, dim))) {
- scf::IfOp ifOp = compareBuilder(builder, loc, args[0], args[1],
- item.value(), (item.index() == dim - 1));
- if (item.index() == 0) {
+ auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
+ scf::IfOp ifOp = compareBuilder(builder, loc, i, j, buffer, (k == nx - 1));
+ if (k == 0) {
topIfOp = ifOp;
} else {
OpBuilder::InsertionGuard insertionGuard(builder);
builder.setInsertionPointAfter(ifOp);
builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
}
- }
+ };
+
+ forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
builder.setInsertionPointAfter(topIfOp);
builder.create<func::ReturnOp>(loc, topIfOp.getResult(0));
@@ -180,8 +242,10 @@ static scf::IfOp createEqCompare(OpBuilder &builder, Location loc, Value i,
// else if (x2[2] != x2[j]))
// and so on ...
static void createEqCompareFunc(OpBuilder &builder, ModuleOp unused,
- func::FuncOp func, size_t dim) {
- createCompareFuncImplementation(builder, unused, func, dim, createEqCompare);
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo) {
+ createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
+ createEqCompare);
}
/// Generates an if-statement to compare whether x[i] is less than x[j].
@@ -238,8 +302,9 @@ static scf::IfOp createLessThanCompare(OpBuilder &builder, Location loc,
// else if (x1[j] < x1[i]))
// and so on ...
static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
- func::FuncOp func, size_t dim) {
- createCompareFuncImplementation(builder, unused, func, dim,
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo) {
+ createCompareFuncImplementation(builder, unused, func, nx, ny, isCoo,
createLessThanCompare);
}
@@ -257,7 +322,8 @@ static void createLessThanFunc(OpBuilder &builder, ModuleOp unused,
// return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, size_t dim) {
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
@@ -292,12 +358,13 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
// Compare xs[p] < xs[mid].
SmallVector<Value, 6> compareOperands{p, mid};
+ uint64_t numXBuffers = isCoo ? 1 : nx;
compareOperands.append(args.begin() + xStartIdx,
- args.begin() + xStartIdx + dim);
+ args.begin() + xStartIdx + numXBuffers);
Type i1Type = IntegerType::get(module.getContext(), 1, IntegerType::Signless);
- FlatSymbolRefAttr lessThanFunc =
- getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
- dim, compareOperands, createLessThanFunc);
+ FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+ builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
+ compareOperands, createLessThanFunc);
Value cond2 = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
@@ -324,7 +391,8 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
/// xs[i] == xs[p].
static std::pair<Value, Value>
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
- ValueRange xs, Value i, Value p, size_t dim, int step) {
+ ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
+ bool isCoo, int step) {
Location loc = func.getLoc();
scf::WhileOp whileOp =
builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
@@ -344,9 +412,9 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands.append(xs.begin(), xs.end());
MLIRContext *context = module.getContext();
Type i1Type = IntegerType::get(context, 1, IntegerType::Signless);
- FlatSymbolRefAttr lessThanFunc =
- getMangledSortHelperFunc(builder, func, {i1Type}, kLessThanFuncNamePrefix,
- dim, compareOperands, createLessThanFunc);
+ FlatSymbolRefAttr lessThanFunc = getMangledSortHelperFunc(
+ builder, func, {i1Type}, kLessThanFuncNamePrefix, nx, ny, isCoo,
+ compareOperands, createLessThanFunc);
Value cond = builder
.create<func::CallOp>(loc, lessThanFunc, TypeRange{i1Type},
compareOperands)
@@ -365,8 +433,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands[0] = i;
compareOperands[1] = p;
FlatSymbolRefAttr compareEqFunc = getMangledSortHelperFunc(
- builder, func, {i1Type}, kCompareEqFuncNamePrefix, dim, compareOperands,
- createEqCompareFunc);
+ builder, func, {i1Type}, kCompareEqFuncNamePrefix, nx, ny, isCoo,
+ compareOperands, createEqCompareFunc);
Value compareEq =
builder
.create<func::CallOp>(loc, compareEqFunc, TypeRange{i1Type},
@@ -405,7 +473,8 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
// return p
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, size_t dim) {
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
@@ -442,11 +511,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
j = after->getArgument(1);
p = after->getArgument(2);
- auto [iresult, iCompareEq] = createScanLoop(
- builder, module, func, args.slice(xStartIdx, dim), i, p, dim, 1);
+ uint64_t numXBuffers = isCoo ? 1 : nx;
+ auto [iresult, iCompareEq] =
+ createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
+ i, p, nx, ny, isCoo, 1);
i = iresult;
- auto [jresult, jCompareEq] = createScanLoop(
- builder, module, func, args.slice(xStartIdx, dim), j, p, dim, -1);
+ auto [jresult, jCompareEq] =
+ createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
+ j, p, nx, ny, isCoo, -1);
j = jresult;
// If i < j:
@@ -455,7 +527,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value, 6> swapOperands{i, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
- createSwap(builder, loc, swapOperands);
+ createSwap(builder, loc, swapOperands, nx, ny, isCoo);
// If the pivot is moved, update p with the new pivot.
Value icond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -515,7 +587,8 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
// }
// }
static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, size_t dim) {
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
@@ -532,8 +605,8 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
// The if-stmt true branch.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
- builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, dim,
- args, createPartitionFunc);
+ builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
+ ny, isCoo, args, createPartitionFunc);
auto p = builder.create<func::CallOp>(
loc, partitionFunc, TypeRange{IndexType::get(context)}, ValueRange(args));
@@ -567,7 +640,8 @@ static void createSortNonstableFunc(OpBuilder &builder, ModuleOp module,
// }
// }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, size_t dim) {
+ func::FuncOp func, uint64_t nx, uint64_t ny,
+ bool isCoo) {
OpBuilder::InsertionGuard insertionGuard(builder);
Block *entryBlock = func.addEntryBlock();
builder.setInsertionPointToStart(entryBlock);
@@ -587,20 +661,23 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
// Binary search to find the insertion point p.
SmallVector<Value, 6> operands{lo, i};
- operands.append(args.begin() + xStartIdx, args.begin() + xStartIdx + dim);
+ operands.append(args.begin() + xStartIdx, args.end());
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
- builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
- dim, operands, createBinarySearchFunc);
+ builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
+ ny, isCoo, operands, createBinarySearchFunc);
Value p = builder
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
operands)
.getResult(0);
// Move the value at data[i] to a temporary location.
- ValueRange data = args.drop_front(xStartIdx);
+ operands[0] = operands[1] = i;
SmallVector<Value, 6> d;
- for (Value v : data)
- d.push_back(builder.create<memref::LoadOp>(loc, v, i));
+ forEachIJPairInAllBuffers(
+ builder, loc, operands, nx, ny, isCoo,
+ [&](uint64_t unused, Value i, Value unused2, Value buffer) {
+ d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
+ });
// Start the inner for-stmt with induction variable j, for moving data[p..i)
// to data[p+1..i+1).
@@ -610,21 +687,58 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(forOpJ.getBody());
Value j = forOpJ.getInductionVar();
Value imj = builder.create<arith::SubIOp>(loc, i, j);
- Value imjm1 = builder.create<arith::SubIOp>(loc, imj, c1);
- for (Value v : data) {
- Value t = builder.create<memref::LoadOp>(loc, v, imjm1);
- builder.create<memref::StoreOp>(loc, t, v, imj);
- }
+ operands[1] = imj;
+ operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
+ forEachIJPairInAllBuffers(
+ builder, loc, operands, nx, ny, isCoo,
+ [&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
+ Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
+ builder.create<memref::StoreOp>(loc, t, buffer, imj);
+ });
// Store the value at data[i] to data[p].
builder.setInsertionPointAfter(forOpJ);
- for (auto it : llvm::zip(d, data))
- builder.create<memref::StoreOp>(loc, std::get<0>(it), std::get<1>(it), p);
+ operands[0] = operands[1] = p;
+ forEachIJPairInAllBuffers(
+ builder, loc, operands, nx, ny, isCoo,
+ [&](uint64_t k, Value p, Value usused, Value buffer) {
+ builder.create<memref::StoreOp>(loc, d[k], buffer, p);
+ });
builder.setInsertionPointAfter(forOpI);
builder.create<func::ReturnOp>(loc);
}
+/// Implements the rewriting for operator sort and sort_coo.
+template <typename OpTy>
+LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
+ uint64_t ny, bool isCoo,
+ PatternRewriter &rewriter) {
+ Location loc = op.getLoc();
+ SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()};
+
+ // Convert `values` to have dynamic shape and append them to `operands`.
+ for (Value v : xys) {
+ auto mtp = v.getType().cast<MemRefType>();
+ if (!mtp.isDynamicDim(0)) {
+ auto newMtp =
+ MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
+ v = rewriter.create<memref::CastOp>(loc, newMtp, v);
+ }
+ operands.push_back(v);
+ }
+ auto insertPoint = op->template getParentOfType<func::FuncOp>();
+ SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
+ : kSortNonstableFuncNamePrefix);
+ FuncGeneratorType funcGenerator =
+ op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+ FlatSymbolRefAttr func =
+ getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
+ ny, isCoo, operands, funcGenerator);
+ rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
+ return success();
+}
+
//===---------------------------------------------------------------------===//
// The actual sparse buffer rewriting rules.
//===---------------------------------------------------------------------===//
@@ -755,34 +869,33 @@ struct SortRewriter : public OpRewritePattern<SortOp> {
LogicalResult matchAndRewrite(SortOp op,
PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- SmallVector<Value, 6> operands{constantIndex(rewriter, loc, 0), op.getN()};
-
- // Convert `values` to have dynamic shape and append them to `operands`.
- auto addValues = [&](ValueRange values) {
- for (Value v : values) {
- auto mtp = v.getType().cast<MemRefType>();
- if (!mtp.isDynamicDim(0)) {
- auto newMtp =
- MemRefType::get({ShapedType::kDynamicSize}, mtp.getElementType());
- v = rewriter.create<memref::CastOp>(loc, newMtp, v);
- }
- operands.push_back(v);
- }
- };
- ValueRange xs = op.getXs();
- addValues(xs);
- addValues(op.getYs());
- auto insertPoint = op->getParentOfType<func::FuncOp>();
- SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
- : kSortNonstableFuncNamePrefix);
- FuncGeneratorType funcGenerator =
- op.getStable() ? createSortStableFunc : createSortNonstableFunc;
- FlatSymbolRefAttr func =
- getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
- xs.size(), operands, funcGenerator);
- rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
- return success();
+ SmallVector<Value, 6> xys(op.getXs());
+ xys.append(op.getYs().begin(), op.getYs().end());
+ return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
+ /*isCoo=*/false, rewriter);
+ }
+};
+
+/// Sparse rewriting rule for the sort_coo operator.
+struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
+public:
+ using OpRewritePattern<SortCooOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SortCooOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<Value, 6> xys;
+ xys.push_back(op.getXy());
+ xys.append(op.getYs().begin(), op.getYs().end());
+ uint64_t nx = 1;
+ if (auto nxAttr = op.getNxAttr())
+ nx = nxAttr.getInt();
+
+ uint64_t ny = 0;
+ if (auto nyAttr = op.getNyAttr())
+ ny = nyAttr.getInt();
+
+ return matchAndRewriteSortOp(op, xys, nx, ny,
+ /*isCoo=*/true, rewriter);
}
};
@@ -796,5 +909,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
bool enableBufferInitialization) {
patterns.add<PushBackRewriter>(patterns.getContext(),
enableBufferInitialization);
- patterns.add<SortRewriter>(patterns.getContext());
+ patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
}
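
To make the mangling scheme used by getMangledSortHelperFuncName concrete, this is the kind of private helper the rewriting is expected to synthesize for a non-stable sort_coo with nx = 2, ny = 1, and an f32 and an i32 buffer sorted jointly (signature copied from the CHECK lines in the test below):

  // <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type>
  func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(
      %arg0: index, %arg1: index, %arg2: memref<?xindex>,
      %arg3: memref<?xf32>, %arg4: memref<?xi32>)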
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index f74eb5f3b0313..c153dcd3d7057 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -173,6 +173,7 @@ struct SparseTensorCodegenPass
// Most ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addLegalOp<SortOp>();
+ target.addLegalOp<SortCooOp>();
target.addLegalOp<PushBackOp>();
// All dynamic rules below accept new function, call, return, and various
// tensor and bufferization operations as legal output of the rewriting
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index f5634524f7e66..18140def32b7f 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -194,3 +194,33 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m
sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
+
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_compare_eq_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL: func.func @sparse_sort_coo
+func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+ sparse_tensor.sort_coo %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
+
+// -----
+
+// Only check the generated supporting functions. We have integration test to
+// verify correctness of the generated code.
+//
+// CHECK-DAG: func.func private @_sparse_less_than_2_index_coo_1(%arg0: index, %arg1: index, %arg2: memref<?xindex>) -> i1 {
+// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-LABEL: func.func @sparse_sort_coo_stable
+func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
+ sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
+}
+
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
new file mode 100644
index 0000000000000..2efd2e481b710
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -0,0 +1,134 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+ // Stores 5 values to the memref buffer.
+ func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
+ %v3: i32, %v4: i32) -> () {
+ %i0 = arith.constant 0 : index
+ %i1 = arith.constant 1 : index
+ %i2 = arith.constant 2 : index
+ %i3 = arith.constant 3 : index
+ %i4 = arith.constant 4 : index
+ memref.store %v0, %b[%i0] : memref<?xi32>
+ memref.store %v1, %b[%i1] : memref<?xi32>
+ memref.store %v2, %b[%i2] : memref<?xi32>
+ memref.store %v3, %b[%i3] : memref<?xi32>
+ memref.store %v4, %b[%i4] : memref<?xi32>
+ return
+ }
+
+ // Stores 5 values to the memref buffer.
+ func.func @storeValuesToStrided(%b: memref<?xi32, strided<[4], offset: ?>>, %v0: i32, %v1: i32, %v2: i32,
+ %v3: i32, %v4: i32) -> () {
+ %i0 = arith.constant 0 : index
+ %i1 = arith.constant 1 : index
+ %i2 = arith.constant 2 : index
+ %i3 = arith.constant 3 : index
+ %i4 = arith.constant 4 : index
+ memref.store %v0, %b[%i0] : memref<?xi32, strided<[4], offset: ?>>
+ memref.store %v1, %b[%i1] : memref<?xi32, strided<[4], offset: ?>>
+ memref.store %v2, %b[%i2] : memref<?xi32, strided<[4], offset: ?>>
+ memref.store %v3, %b[%i3] : memref<?xi32, strided<[4], offset: ?>>
+ memref.store %v4, %b[%i4] : memref<?xi32, strided<[4], offset: ?>>
+ return
+ }
+
+ // The main driver.
+ func.func @entry() {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %c2 = arith.constant 2 : i32
+ %c3 = arith.constant 3 : i32
+ %c4 = arith.constant 4 : i32
+ %c5 = arith.constant 5 : i32
+ %c6 = arith.constant 6 : i32
+ %c7 = arith.constant 7 : i32
+ %c8 = arith.constant 8 : i32
+ %c9 = arith.constant 9 : i32
+ %c10 = arith.constant 10 : i32
+ %c100 = arith.constant 100 : i32
+
+ %i0 = arith.constant 0 : index
+ %i1 = arith.constant 1 : index
+ %i2 = arith.constant 2 : index
+ %i3 = arith.constant 3 : index
+ %i4 = arith.constant 4 : index
+ %i5 = arith.constant 5 : index
+
+ // Prepare a buffer for x0, x1, x2, y0 and a buffer for y1.
+ %xys = memref.alloc() : memref<20xi32>
+ %xy = memref.cast %xys : memref<20xi32> to memref<?xi32>
+ %x0 = memref.subview %xy[%i0][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+ %x1 = memref.subview %xy[%i1][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+ %x2 = memref.subview %xy[%i2][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+ %y0 = memref.subview %xy[%i3][%i5][%i4] : memref<?xi32> to memref<?xi32, strided<[4], offset: ?>>
+ %y1s = memref.alloc() : memref<7xi32>
+ %y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
+
+ // Sort "parallel arrays".
+ // CHECK: ( 1, 1, 2, 5, 10 )
+ // CHECK: ( 3, 3, 1, 10, 1 )
+ // CHECK: ( 9, 9, 4, 7, 2 )
+ // CHECK: ( 8, 7, 10, 9, 6 )
+ // CHECK: ( 4, 7, 7, 9, 5 )
+ call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ sparse_tensor.sort_coo %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ : memref<?xi32> jointly memref<?xi32>
+ %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x0v : vector<5xi32>
+ %x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x1v : vector<5xi32>
+ %x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x2v : vector<5xi32>
+ %y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %y0v : vector<5xi32>
+ %y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %y1v : vector<5xi32>
+ // Stable sort.
+ // CHECK: ( 1, 1, 2, 5, 10 )
+ // CHECK: ( 3, 3, 1, 10, 1 )
+ // CHECK: ( 9, 9, 4, 7, 2 )
+ // CHECK: ( 8, 7, 10, 9, 6 )
+ // CHECK: ( 4, 7, 7, 9, 5 )
+ call @storeValuesToStrided(%x0, %c10, %c2, %c1, %c5, %c1)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%x1, %c1, %c1, %c3, %c10, %c3)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%x2, %c2, %c4, %c9, %c7, %c9)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesToStrided(%y0, %c6, %c10, %c8, %c9, %c7)
+ : (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
+ call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
+ : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
+ sparse_tensor.sort_coo stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ : memref<?xi32> jointly memref<?xi32>
+ %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x0v2 : vector<5xi32>
+ %x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x1v2 : vector<5xi32>
+ %x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x2v2 : vector<5xi32>
+ %y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %y0v2 : vector<5xi32>
+ %y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
+ vector.print %y1v2 : vector<5xi32>
+
+ // Release the buffers.
+ memref.dealloc %xy : memref<?xi32>
+ memref.dealloc %y1 : memref<?xi32>
+ return
+ }
+}