[Mlir-commits] [mlir] [mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (PR #66722)
Peiming Liu
llvmlistbot at llvm.org
Mon Sep 18 17:08:24 PDT 2023
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/66722
The functionality of the two operations largely overlaps; let's simplify things and keep only one of them.
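In short, `sparse_tensor.sort` goes away and the surviving `sparse_tensor.sort_coo` takes an `AffineMapAttr` permutation (`nx`) instead of an integer count. A minimal sketch of the updated builder call, taken from the codegen changes in this patch (values such as `nse`, `xs`, and `ys` stand for whatever the caller already has; the identity map is just one example permutation):

```c++
// Sort `nse` coordinate tuples stored interleaved in the single buffer `xs`,
// carrying the values in `ys` along. The identity AffineMap over `lvlRank`
// dimensions replaces the old integer `nx` attribute; ny = 0 means no extra
// y values are packed inside the xy buffer.
auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
                           rewriter.getIndexAttr(0),
                           SparseTensorSortKind::HybridQuickSort);
```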
From 3daa6dc70c102f924f8e8e88d4eb8b605aff6f22 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Sep 2023 21:48:00 +0000
Subject: [PATCH 1/4] [mlir][sparse] unify sparse_tensor.sort/sort_coo
operations into one.
---
.../SparseTensor/IR/SparseTensorOps.td | 53 +---
.../SparseTensor/IR/SparseTensorDialect.cpp | 30 +--
.../Transforms/SparseBufferRewriting.cpp | 226 +++++++++---------
.../Transforms/SparseTensorCodegen.cpp | 12 +-
.../Transforms/SparseTensorPasses.cpp | 1 -
.../Transforms/SparseTensorRewriting.cpp | 38 ++-
6 files changed, 137 insertions(+), 223 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 94301dbcd9f7b42..d83d1ba03feb848 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,61 +762,10 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
// Sparse Tensor Sorting Operations.
//===----------------------------------------------------------------------===//
-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
- Arguments<(ins Index:$n,
- Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
- Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
- SparseTensorSortKindAttr:$algorithm)> {
- string summary = "Sorts the arrays in xs and ys lexicographically on the "
- "integral values found in the xs list";
- string description = [{
- Lexicographically sort the first `n` values in `xs` along with the values in
- `ys`. Conceptually, the values being sorted are tuples produced by
- `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
- along with values in `xs`, but values in `ys` don't affect the
- lexicographical order. The order in which arrays appear in `xs` affects the
- sorting result. The operator updates `xs` and `ys` in place with the result
- of the sorting.
-
- For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
- "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
- output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
- Buffers in `xs` needs to have the same integral element type while buffers
- in `ys` can have different numeric element types. All buffers in `xs` and
- `ys` should have a dimension not less than `n`. The behavior of the operator
- is undefined if this condition is not met. The operator requires at least
- one buffer in `xs` while `ys` can be empty.
-
- The enum attribute `algorithm` indicates the sorting algorithm used to
- implement the operator: hybrid_quick_sort, insertion_sort_stable,
- quick_sort, or heap_sort.
-
- Note that this operation is "impure" in the sense that its behavior is
- solely defined by side-effects and not SSA values.
-
- Example:
-
- ```mlir
- sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
- : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
- ```
-
- ```mlir
- sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
- { alg=1 : index}
- : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
- ```
- }];
- let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
- "`:` type($xs) (`jointly` type($ys)^)?";
- let hasVerifier = 1;
-}
-
def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
- OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+ AffineMapAttr:$nx, OptionalAttr<IndexAttr>:$ny,
SparseTensorSortKindAttr:$algorithm)> {
let summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e71d2a8dd623a6a..3cd0847bdf73765 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1353,34 +1353,6 @@ LogicalResult SelectOp::verify() {
return success();
}
-LogicalResult SortOp::verify() {
- if (getXs().empty())
- return emitError("need at least one xs buffer.");
-
- std::optional<int64_t> n = getConstantIntValue(getN());
-
- Type xtp = getMemRefType(getXs().front()).getElementType();
- auto checkTypes = [&](ValueRange operands,
- bool checkEleType = true) -> LogicalResult {
- for (Value opnd : operands) {
- auto mtp = getMemRefType(opnd);
- const DynSize sh = mtp.getShape()[0];
- // We can't check the size of dynamic dimension at compile-time, but all
- // xs and ys should have a dimension not less than n at runtime.
- if (n && !ShapedType::isDynamic(sh) && sh < n.value())
- return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
- ": {0} < {1}",
- sh, n.value()));
-
- if (checkEleType && xtp != mtp.getElementType())
- return emitError("mismatch xs element types");
- }
- return success();
- };
- RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
- return n ? checkTypes(getYs(), false) : success();
-}
-
LogicalResult SortCooOp::verify() {
std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
@@ -1391,7 +1363,7 @@ LogicalResult SortCooOp::verify() {
uint64_t n = cn.value();
uint64_t nx = 1;
if (auto nxAttr = getNxAttr()) {
- nx = nxAttr.getInt();
+ nx = nxAttr.getAffineMap().getNumResults();
if (nx < 1)
emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 029ecb0708941fe..255c78f8b4eb61f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -46,29 +46,29 @@ static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
using FuncGeneratorType = function_ref<void(
- OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
+ OpBuilder &, ModuleOp, func::FuncOp, AffineMap, uint64_t, bool, uint32_t)>;
/// Constructs a function name with this format to facilitate quick sort:
-/// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
-/// <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
+/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
+/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
- StringRef namePrefix, uint64_t nx,
+ StringRef namePrefix, AffineMap xPerm,
uint64_t ny, bool isCoo,
ValueRange operands) {
- nameOstream << namePrefix << nx << "_"
+ nameOstream << namePrefix << xPerm << "_"
<< getMemRefType(operands[xStartIdx]).getElementType();
if (isCoo)
nameOstream << "_coo_" << ny;
- uint64_t yBufferOffset = isCoo ? 1 : nx;
+ uint64_t yBufferOffset = isCoo ? 1 : xPerm.getNumResults();
for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
nameOstream << "_" << getMemRefType(v).getElementType();
}
/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
-/// parameters `nx` and `ny` tell the number of x and y values provided
+/// parameters `xPerm` and `ny` tell the number of x and y values provided
/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
//
@@ -79,12 +79,12 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
static FlatSymbolRefAttr
getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
TypeRange resultTypes, StringRef namePrefix,
- uint64_t nx, uint64_t ny, bool isCoo,
+ AffineMap xPerm, uint64_t ny, bool isCoo,
ValueRange operands, FuncGeneratorType createFunc,
uint32_t nTrailingP = 0) {
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
- getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+ getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, isCoo,
operands.drop_back(nTrailingP));
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +101,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
loc, nameOstream.str(),
FunctionType::get(context, operands.getTypes(), resultTypes));
func.setPrivate();
- createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
+ createFunc(builder, module, func, xPerm, ny, isCoo, nTrailingP);
}
return result;
@@ -110,15 +110,17 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
- OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+ OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+ uint64_t ny, bool isCoo,
+ function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
Value iOffset, jOffset;
if (isCoo) {
- Value cstep = constantIndex(builder, loc, nx + ny);
+ Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
}
- for (uint64_t k = 0; k < nx; k++) {
+ for (AffineExpr e : xPerm.getResults()) {
+ unsigned k = e.cast<AffineDimExpr>().getPosition();
scf::IfOp ifOp;
Value i, j, buffer;
if (isCoo) {
@@ -138,21 +140,30 @@ static void forEachIJPairInXs(
/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
- OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+ OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+ uint64_t ny, bool isCoo,
+ function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
- // Create code for the first (nx + ny) buffers. When isCoo==true, these
+ // Create code for the first (xPerm + ny) buffers. When isCoo==true, these
// logical buffers are all from the xy buffer of the sort_coo operator.
- forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+ SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
+ xPerm.getResults().end());
+ for (unsigned y = 0; y < ny; y++) {
+ exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
+ }
+ AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
+ assert(xyPerm.isPermutation());
+
+ forEachIJPairInXs(builder, loc, args, xyPerm, 0, isCoo, bodyBuilder);
- uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+ uint64_t numHandledBuffers = isCoo ? 1 : xPerm.getNumResults() + ny;
// Create code for the remaining buffers.
Value i = args[0];
Value j = args[1];
for (const auto &arg :
llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
- bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+ bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
}
}
@@ -168,7 +179,7 @@ static void forEachIJPairInAllBuffers(
// ...
// swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
- uint64_t nx, uint64_t ny, bool isCoo) {
+ AffineMap xPerm, uint64_t ny, bool isCoo) {
auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -176,20 +187,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
builder.create<memref::StoreOp>(loc, vi, buffer, j);
};
- forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
+ forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, isCoo, swapOnePair);
}
/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
/// each pair is create via `compareBuilder`.
static Value createInlinedCompareImplementation(
- OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo,
+ OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+ uint64_t ny, bool isCoo,
function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
compareBuilder) {
Value result;
auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
bool isFirstDim = (k == 0);
- bool isLastDim = (k == nx - 1);
+ bool isLastDim = (k == xPerm.getNumResults() - 1);
Value val =
compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
if (isFirstDim) {
@@ -202,7 +213,7 @@ static Value createInlinedCompareImplementation(
}
};
- forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
+ forEachIJPairInXs(builder, loc, args, xPerm, ny, isCoo, bodyBuilder);
builder.setInsertionPointAfterValue(result);
return result;
@@ -252,13 +263,14 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
// else if (x2[2] != x2[j]))
// and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
- ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP = 0) {
+ ValueRange args, AffineMap xPerm,
+ uint64_t ny, bool isCoo,
+ uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
- return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
- createEqCompare);
+ return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
+ isCoo, createEqCompare);
}
/// Generates code to compare whether x[i] is less than x[j] and returns the
@@ -306,13 +318,14 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
// else
// and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
- ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP = 0) {
+ ValueRange args, AffineMap xPerm,
+ uint64_t ny, bool isCoo,
+ uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
- return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
- createLessThanCompare);
+ return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
+ isCoo, createLessThanCompare);
}
/// Creates a function to use a binary search to find the insertion point for
@@ -329,8 +342,9 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
// return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP = 0) {
+ func::FuncOp func, AffineMap xPerm,
+ uint64_t ny, bool isCoo,
+ uint32_t nTrailingP = 0) {
// Binary search doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -368,11 +382,11 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
// Compare xs[p] < xs[mid].
SmallVector<Value> compareOperands{p, mid};
- uint64_t numXBuffers = isCoo ? 1 : nx;
+ uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
Value cond2 =
- createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+ createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
// Update lo and hi for the WhileOp as follows:
// if (xs[p] < xs[mid]))
// hi = mid;
@@ -394,7 +408,7 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
/// xs[i] == xs[p].
static std::pair<Value, Value>
createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
- ValueRange xs, Value i, Value p, uint64_t nx, uint64_t ny,
+ ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny,
bool isCoo, int step) {
Location loc = func.getLoc();
scf::WhileOp whileOp =
@@ -414,7 +428,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
}
compareOperands.append(xs.begin(), xs.end());
Value cond =
- createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+ createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
Block *after =
@@ -429,7 +443,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands[0] = i;
compareOperands[1] = p;
Value compareEq =
- createInlinedEqCompare(builder, loc, compareOperands, nx, ny, isCoo);
+ createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny, isCoo);
return std::make_pair(whileOp.getResult(0), compareEq);
}
@@ -438,7 +452,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
/// right after the swap instructions.
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
- uint64_t nx, uint64_t ny, bool isCoo,
+ AffineMap xPerm, uint64_t ny, bool isCoo,
SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands,
Value a, Value b) {
@@ -446,59 +460,59 @@ static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
compareOperands[0] = b;
compareOperands[1] = a;
Value cond =
- createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+ createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
swapOperands[0] = b;
swapOperands[1] = a;
- createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
return ifOp;
}
/// Creates code to insert the 3rd element to a list of two sorted elements.
-static void createInsert3rd(OpBuilder &builder, Location loc, uint64_t nx,
+static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
uint64_t ny, bool isCoo,
SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2) {
- scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
+ scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
swapOperands, compareOperands, v1, v2);
- createCompareThenSwap(builder, loc, nx, ny, isCoo, swapOperands,
+ createCompareThenSwap(builder, loc, xPerm, ny, isCoo, swapOperands,
compareOperands, v0, v1);
builder.setInsertionPointAfter(ifOp);
}
/// Creates code to sort 3 elements.
-static void createSort3(OpBuilder &builder, Location loc, uint64_t nx,
+static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
uint64_t ny, bool isCoo,
SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2) {
// Sort the first 2 elements.
scf::IfOp ifOp1 = createCompareThenSwap(
- builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0, v1);
+ builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0, v1);
builder.setInsertionPointAfter(ifOp1);
// Insert the 3th element.
- createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
+ createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands,
v0, v1, v2);
}
/// Creates code to sort 5 elements.
-static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
+static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
uint64_t ny, bool isCoo,
SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2, Value v3, Value v4) {
// Sort the first 3 elements.
- createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v0,
+ createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0,
v1, v2);
auto insert4th = [&]() {
scf::IfOp ifOp = createCompareThenSwap(
- builder, loc, nx, ny, isCoo, swapOperands, compareOperands, v2, v3);
- createInsert3rd(builder, loc, nx, ny, isCoo, swapOperands, compareOperands,
- v0, v1, v2);
+ builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v2, v3);
+ createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands,
+ compareOperands, v0, v1, v2);
builder.setInsertionPointAfter(ifOp);
};
@@ -506,7 +520,7 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
insert4th();
// Insert the 5th element.
- scf::IfOp ifOp = createCompareThenSwap(builder, loc, nx, ny, isCoo,
+ scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
swapOperands, compareOperands, v3, v4);
insert4th();
builder.setInsertionPointAfter(ifOp);
@@ -517,11 +531,11 @@ static void createSort5(OpBuilder &builder, Location loc, uint64_t nx,
/// the number of values in range [lo, hi) is more than a threshold, we also
/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
+ func::FuncOp func, AffineMap xPerm, uint64_t ny,
bool isCoo, Value lo, Value hi, Value mi,
ValueRange args) {
SmallVector<Value> compareOperands{mi, lo};
- uint64_t numXBuffers = isCoo ? 1 : nx;
+ uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
SmallVector<Value> swapOperands{mi, lo};
@@ -537,7 +551,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// When len < 1000, choose pivot from median of 3 values.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
- createSort3(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo,
+ createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
mi, hi);
// When len >= 1000, choose pivot from median of 5 values.
@@ -549,8 +563,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
// Value b is the middle between [mi, hi].
b = builder.create<arith::ShRUIOp>(loc, b, c1);
- createSort5(builder, loc, nx, ny, isCoo, swapOperands, compareOperands, lo, a,
- mi, b, hi);
+ createSort5(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
+ a, mi, b, hi);
builder.setInsertionPointAfter(lenIf);
}
@@ -586,7 +600,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// }
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
+ func::FuncOp func, AffineMap xPerm, uint64_t ny,
bool isCoo, uint32_t nTrailingP = 0) {
// Quick sort partition doesn't use trailing parameters.
(void)nTrailingP;
@@ -606,7 +620,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
Value i = lo;
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
- createChoosePivot(builder, module, func, nx, ny, isCoo, i, j, p, args);
+ createChoosePivot(builder, module, func, xPerm, ny, isCoo, i, j, p, args);
Value trueVal = constantI1(builder, loc, true); // The value for while (true)
SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
@@ -628,14 +642,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
j = after->getArgument(1);
p = after->getArgument(2);
- uint64_t numXBuffers = isCoo ? 1 : nx;
+ uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
auto [iresult, iCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
- i, p, nx, ny, isCoo, 1);
+ i, p, xPerm, ny, isCoo, 1);
i = iresult;
auto [jresult, jCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
- j, p, nx, ny, isCoo, -1);
+ j, p, xPerm, ny, isCoo, -1);
j = jresult;
// If i < j:
@@ -645,7 +659,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value> swapOperands{i, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
- createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
// If the pivot is moved, update p with the new pivot.
Value icond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -737,7 +751,7 @@ static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
// }
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
+ func::FuncOp func, AffineMap xPerm, uint64_t ny,
bool isCoo, uint32_t nTrailingP) {
// The value n is passed in as a trailing parameter.
assert(nTrailingP == 1);
@@ -768,7 +782,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
Value c1 = constantIndex(builder, loc, 1);
SmallVector<Value> compareOperands{start, start};
- uint64_t numXBuffers = isCoo ? 1 : nx;
+ uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
@@ -794,7 +808,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
compareOperands[0] = lChildIdx;
compareOperands[1] = rChildIdx;
Value cond2 =
- createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+ createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
scf::IfOp if2 =
builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
builder.setInsertionPointToStart(&if2.getThenRegion().front());
@@ -826,7 +840,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
compareOperands[0] = start;
compareOperands[1] = childIdx;
Value cond =
- createInlinedLessThan(builder, loc, compareOperands, nx, ny, isCoo);
+ createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
// The after-region of the WhileOp.
@@ -836,7 +850,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
childIdx = after->getArgument(2);
SmallVector<Value> swapOperands{start, childIdx};
swapOperands.append(args.begin() + xStartIdx, args.end());
- createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
start = childIdx;
Value cond2 =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
@@ -869,7 +883,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
// shiftdown(lo, lo, l-1)
// }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
+ func::FuncOp func, AffineMap xPerm, uint64_t ny,
bool isCoo, uint32_t nTrailingP) {
// Heap sort function doesn't have trailing parameters.
(void)nTrailingP;
@@ -897,7 +911,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
shiftDownOperands.append(args.begin() + xStartIdx, args.end());
shiftDownOperands.push_back(n);
FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
- builder, func, TypeRange(), kShiftDownFuncNamePrefix, nx, ny, isCoo,
+ builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, isCoo,
shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
shiftDownOperands);
@@ -912,7 +926,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
SmallVector<Value> swapOperands{lo, loplm1};
swapOperands.append(args.begin() + xStartIdx, args.end());
- createSwap(builder, loc, swapOperands, nx, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
shiftDownOperands[1] = lo;
shiftDownOperands[shiftDownOperands.size() - 1] =
builder.create<arith::SubIOp>(loc, l, c1);
@@ -928,7 +942,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
/// the bigger partition to be processed by the enclosed while-loop.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
- ValueRange args, uint64_t nx, uint64_t ny, bool isCoo,
+ ValueRange args, AffineMap xPerm, uint64_t ny, bool isCoo,
uint32_t nTrailingP) {
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
@@ -937,7 +951,7 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
SmallVector<Type, 2> types(2, lo.getType()); // Only two types.
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
- builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, nx,
+ builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
Value p = builder
.create<func::CallOp>(loc, partitionFunc,
@@ -1008,8 +1022,8 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
// }
// }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP) {
+ func::FuncOp func, AffineMap xPerm,
+ uint64_t ny, bool isCoo, uint32_t nTrailingP) {
// Stable sort function doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -1034,8 +1048,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
SmallVector<Value> operands{lo, i};
operands.append(args.begin() + xStartIdx, args.end());
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
- builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, nx,
- ny, isCoo, operands, createBinarySearchFunc);
+ builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
+ xPerm, ny, isCoo, operands, createBinarySearchFunc);
Value p = builder
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
operands)
@@ -1045,7 +1059,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
operands[0] = operands[1] = i;
SmallVector<Value> d;
forEachIJPairInAllBuffers(
- builder, loc, operands, nx, ny, isCoo,
+ builder, loc, operands, xPerm, ny, isCoo,
[&](uint64_t unused, Value i, Value unused2, Value buffer) {
d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
});
@@ -1061,7 +1075,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
operands[1] = imj;
operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
forEachIJPairInAllBuffers(
- builder, loc, operands, nx, ny, isCoo,
+ builder, loc, operands, xPerm, ny, isCoo,
[&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
builder.create<memref::StoreOp>(loc, t, buffer, imj);
@@ -1071,7 +1085,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointAfter(forOpJ);
operands[0] = operands[1] = p;
forEachIJPairInAllBuffers(
- builder, loc, operands, nx, ny, isCoo,
+ builder, loc, operands, xPerm, ny, isCoo,
[&](uint64_t k, Value p, Value usused, Value buffer) {
builder.create<memref::StoreOp>(loc, d[k], buffer, p);
});
@@ -1123,7 +1137,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
// }
//
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
+ func::FuncOp func, AffineMap xPerm, uint64_t ny,
bool isCoo, uint32_t nTrailingP) {
assert(nTrailingP == 1 || nTrailingP == 0);
bool isHybrid = (nTrailingP == 1);
@@ -1173,7 +1187,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When len <= limit.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
- builder, func, TypeRange(), kSortStableFuncNamePrefix, nx, ny, isCoo,
+ builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, isCoo,
ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
ValueRange(args).drop_back(nTrailingP));
@@ -1193,7 +1207,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When depth exceeds limit.
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
- builder, func, TypeRange(), kHeapSortFuncNamePrefix, nx, ny, isCoo,
+ builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, isCoo,
ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
ValueRange(args).drop_back(nTrailingP));
@@ -1202,8 +1216,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When depth doesn't exceed limit.
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
args.back() = depthLimit;
- std::tie(lo, hi) =
- createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+ std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
+ isCoo, nTrailingP);
builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
builder.setInsertionPointAfter(depthIf);
@@ -1215,8 +1229,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
lo = lenIf.getResult(0);
hi = lenIf.getResult(1);
} else {
- std::tie(lo, hi) =
- createQuickSort(builder, module, func, args, nx, ny, isCoo, nTrailingP);
+ std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
+ isCoo, nTrailingP);
}
// New [lo, hi) for the next while-loop iteration.
@@ -1229,7 +1243,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
-LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
+LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
uint64_t ny, bool isCoo,
PatternRewriter &rewriter) {
Location loc = op.getLoc();
@@ -1284,9 +1298,9 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
break;
}
- FlatSymbolRefAttr func =
- getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
- ny, isCoo, operands, funcGenerator, nTrailingP);
+ FlatSymbolRefAttr func = getMangledSortHelperFunc(
+ rewriter, insertPoint, TypeRange(), funcName, xPerm, ny, isCoo, operands,
+ funcGenerator, nTrailingP);
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
return success();
}
@@ -1410,20 +1424,6 @@ struct PushBackRewriter : OpRewritePattern<PushBackOp> {
bool enableBufferInitialization;
};
-/// Sparse rewriting rule for the sort operator.
-struct SortRewriter : public OpRewritePattern<SortOp> {
-public:
- using OpRewritePattern<SortOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(SortOp op,
- PatternRewriter &rewriter) const override {
- SmallVector<Value> xys(op.getXs());
- xys.append(op.getYs().begin(), op.getYs().end());
- return matchAndRewriteSortOp(op, xys, op.getXs().size(), /*ny=*/0,
- /*isCoo=*/false, rewriter);
- }
-};
-
/// Sparse rewriting rule for the sort_coo operator.
struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
public:
@@ -1434,15 +1434,13 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
SmallVector<Value> xys;
xys.push_back(op.getXy());
xys.append(op.getYs().begin(), op.getYs().end());
- uint64_t nx = 1;
- if (auto nxAttr = op.getNxAttr())
- nx = nxAttr.getInt();
+ auto xPerm = op.getNx();
uint64_t ny = 0;
if (auto nyAttr = op.getNyAttr())
ny = nyAttr.getInt();
- return matchAndRewriteSortOp(op, xys, nx, ny,
+ return matchAndRewriteSortOp(op, xys, xPerm, ny,
/*isCoo=*/true, rewriter);
}
};
@@ -1457,5 +1455,5 @@ void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns,
bool enableBufferInitialization) {
patterns.add<PushBackRewriter>(patterns.getContext(),
enableBufferInitialization);
- patterns.add<SortRewriter, SortCooRewriter>(patterns.getContext());
+ patterns.add<SortCooRewriter>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 557c5c471c4a77c..4419c39c69927e9 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -890,8 +890,9 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
// If the innermost level is ordered, we need to sort the coordinates
// in the "added" array prior to applying the compression.
if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
- rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
- SparseTensorSortKind::HybridQuickSort);
+ rewriter.create<SortCooOp>(
+ loc, count, added, ValueRange{}, rewriter.getMultiDimIdentityMap(1),
+ rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
// While performing the insertions, we also need to reset the elements
// of the values/filled-switch by only iterating over the set elements,
// to ensure that the runtime complexity remains proportional to the
@@ -1486,9 +1487,10 @@ struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
scf::IfOp ifOp =
rewriter.create<scf::IfOp>(loc, notSorted, /*else*/ false);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
- rewriter.create<SortCooOp>(
- loc, nse, xs, ValueRange{ys}, rewriter.getIndexAttr(lvlRank),
- rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
+ auto xPerm = rewriter.getMultiDimIdentityMap(lvlRank);
+ rewriter.create<SortCooOp>(loc, nse, xs, ValueRange{ys}, xPerm,
+ rewriter.getIndexAttr(0),
+ SparseTensorSortKind::HybridQuickSort);
rewriter.setInsertionPointAfter(ifOp);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index ca7d8a7850b0b19..7d2f0c7f139cda5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -207,7 +207,6 @@ struct SparseTensorCodegenPass
ConversionTarget target(*ctx);
// Most ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
- target.addLegalOp<SortOp>();
target.addLegalOp<SortCooOp>();
target.addLegalOp<PushBackOp>();
// Storage specifier outlives sparse tensor pipeline.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 47f7dad08c8c920..277903dc55b7432 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1206,29 +1206,23 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
// Retrieve the values-array.
Value y = genToValues(rewriter, loc, src);
const auto encSrc = srcTp.getEncoding();
- // Sort the COO tensor so that its elements are ordered via increasing
- // coordinates for the storage ordering of the dst tensor. Use SortCoo
- // if the COO tensor has the same ordering as the dst tensor.
- if (dimRank > 1 && srcTp.hasSameDimToLvl(dstTp)) {
- Value xs = genToCoordinatesBuffer(rewriter, loc, src);
- rewriter.create<SortCooOp>(
- loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(dimRank),
- rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
- } else {
- // Gather the coordinates-arrays in the dst tensor storage order.
- SmallVector<Value> xs(dstLvlRank);
- const Level srcLvlRank = srcTp.getLvlRank();
- for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
- // FIXME: `toOrigDim` is deprecated
- Dimension dim = toOrigDim(encSrc, srcLvl);
- // FIXME: `toStoredDim` is deprecated
- Level dstLvl = toStoredDim(encDst, dim);
- xs[dstLvl] =
- genToCoordinates(rewriter, loc, src, srcLvl, /*cooStart=*/0);
- }
- rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y},
- SparseTensorSortKind::HybridQuickSort);
+ // Builds the dstLvl -> srcLvl permutation maps.
+ SmallVector<AffineExpr> es(dstLvlRank);
+ const Level srcLvlRank = srcTp.getLvlRank();
+ for (Level srcLvl = 0; srcLvl < srcLvlRank; srcLvl++) {
+ // FIXME: `toOrigDim` is deprecated
+ Dimension dim = toOrigDim(encSrc, srcLvl);
+ // FIXME: `toStoredDim` is deprecated
+ Level dstLvl = toStoredDim(encDst, dim);
+ es[dstLvl] = rewriter.getAffineDimExpr(srcLvl);
}
+ auto xPerm = AffineMap::get(dstLvlRank, 0, es, rewriter.getContext());
+ assert(xPerm.isPermutation()); // must be a permutation.
+
+ Value xs = genToCoordinatesBuffer(rewriter, loc, src);
+ rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y}, xPerm,
+ rewriter.getIndexAttr(0),
+ SparseTensorSortKind::HybridQuickSort);
}
// For each element in the COO tensor, insert the element to the dst tensor.
From 6bbb3ba6543139af10d9dc86b6728495c6e77b73 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Sep 2023 22:10:23 +0000
Subject: [PATCH 2/4] remove deadcode after cleanup
---
.../Transforms/SparseBufferRewriting.cpp | 235 ++++++++----------
1 file changed, 103 insertions(+), 132 deletions(-)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 255c78f8b4eb61f..7011578a5afa0b7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -45,23 +45,20 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
-using FuncGeneratorType = function_ref<void(
- OpBuilder &, ModuleOp, func::FuncOp, AffineMap, uint64_t, bool, uint32_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+ AffineMap, uint64_t, uint32_t)>;
/// Constructs a function name with this format to facilitate quick sort:
/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
StringRef namePrefix, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- ValueRange operands) {
+ uint64_t ny, ValueRange operands) {
nameOstream << namePrefix << xPerm << "_"
<< getMemRefType(operands[xStartIdx]).getElementType();
+ nameOstream << "_coo_" << ny;
- if (isCoo)
- nameOstream << "_coo_" << ny;
-
- uint64_t yBufferOffset = isCoo ? 1 : xPerm.getNumResults();
+ constexpr uint64_t yBufferOffset = 1;
for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
nameOstream << "_" << getMemRefType(v).getElementType();
}
@@ -69,22 +66,19 @@ static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
/// parameters `xPerm` and `ny` tell the number of x and y values provided
-/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
-/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+/// by the buffer in xStartIdx.
//
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
// parameters for the sorting functions. Other parameters, such as the recursive
// call depth, are appended to the end of the parameter list as
// "trailing parameters".
-static FlatSymbolRefAttr
-getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
- TypeRange resultTypes, StringRef namePrefix,
- AffineMap xPerm, uint64_t ny, bool isCoo,
- ValueRange operands, FuncGeneratorType createFunc,
- uint32_t nTrailingP = 0) {
+static FlatSymbolRefAttr getMangledSortHelperFunc(
+ OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
+ StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
+ FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
- getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny, isCoo,
+ getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
operands.drop_back(nTrailingP));
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +95,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
loc, nameOstream.str(),
FunctionType::get(context, operands.getTypes(), resultTypes));
func.setPrivate();
- createFunc(builder, module, func, xPerm, ny, isCoo, nTrailingP);
+ createFunc(builder, module, func, xPerm, ny, nTrailingP);
}
return result;
@@ -111,28 +105,20 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
- uint64_t ny, bool isCoo,
+ uint64_t ny,
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
- Value iOffset, jOffset;
- if (isCoo) {
- Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
- iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
- jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
- }
+ Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
+ Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+ Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
for (AffineExpr e : xPerm.getResults()) {
unsigned k = e.cast<AffineDimExpr>().getPosition();
scf::IfOp ifOp;
Value i, j, buffer;
- if (isCoo) {
- Value ck = constantIndex(builder, loc, k);
- i = builder.create<arith::AddIOp>(loc, ck, iOffset);
- j = builder.create<arith::AddIOp>(loc, ck, jOffset);
- buffer = args[xStartIdx];
- } else {
- i = args[0];
- j = args[1];
- buffer = args[xStartIdx + k];
- }
+ Value ck = constantIndex(builder, loc, k);
+ i = builder.create<arith::AddIOp>(loc, ck, iOffset);
+ j = builder.create<arith::AddIOp>(loc, ck, jOffset);
+ buffer = args[xStartIdx];
+
bodyBuilder(k, i, j, buffer);
}
}
@@ -141,11 +127,10 @@ static void forEachIJPairInXs(
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
- uint64_t ny, bool isCoo,
+ uint64_t ny,
function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
- // Create code for the first (xPerm + ny) buffers. When isCoo==true, these
- // logical buffers are all from the xy buffer of the sort_coo operator.
+ // Create code for the first (xPerm + ny) buffers.
SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
xPerm.getResults().end());
for (unsigned y = 0; y < ny; y++) {
@@ -154,10 +139,9 @@ static void forEachIJPairInAllBuffers(
AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
assert(xyPerm.isPermutation());
- forEachIJPairInXs(builder, loc, args, xyPerm, 0, isCoo, bodyBuilder);
-
- uint64_t numHandledBuffers = isCoo ? 1 : xPerm.getNumResults() + ny;
+ forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
+ constexpr uint64_t numHandledBuffers = 1;
// Create code for the remaining buffers.
Value i = args[0];
Value j = args[1];
@@ -179,7 +163,7 @@ static void forEachIJPairInAllBuffers(
// ...
// swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
- AffineMap xPerm, uint64_t ny, bool isCoo) {
+ AffineMap xPerm, uint64_t ny) {
auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -187,14 +171,14 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
builder.create<memref::StoreOp>(loc, vi, buffer, j);
};
- forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, isCoo, swapOnePair);
+ forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
}
/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
/// each pair is create via `compareBuilder`.
static Value createInlinedCompareImplementation(
OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
- uint64_t ny, bool isCoo,
+ uint64_t ny,
function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
compareBuilder) {
Value result;
@@ -213,7 +197,7 @@ static Value createInlinedCompareImplementation(
}
};
- forEachIJPairInXs(builder, loc, args, xPerm, ny, isCoo, bodyBuilder);
+ forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
builder.setInsertionPointAfterValue(result);
return result;
@@ -264,13 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
// and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
ValueRange args, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- uint32_t nTrailingP = 0) {
+ uint64_t ny, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
- isCoo, createEqCompare);
+ createEqCompare);
}
/// Generates code to compare whether x[i] is less than x[j] and returns the
@@ -319,13 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
// and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
ValueRange args, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- uint32_t nTrailingP = 0) {
+ uint64_t ny, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
- isCoo, createLessThanCompare);
+ createLessThanCompare);
}
/// Creates a function to use a binary search to find the insertion point for
@@ -343,8 +325,7 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- uint32_t nTrailingP = 0) {
+ uint64_t ny, uint32_t nTrailingP = 0) {
// Binary search doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -382,11 +363,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
// Compare xs[p] < xs[mid].
SmallVector<Value> compareOperands{p, mid};
- uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
+ constexpr uint64_t numXBuffers = 1;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
- Value cond2 =
- createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+ Value cond2 = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
// Update lo and hi for the WhileOp as follows:
// if (xs[p] < xs[mid]))
// hi = mid;
@@ -406,10 +386,11 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
/// while (xs[i] > xs[p]) i += step (step < 0)
/// The routine returns i as well as a boolean value to indicate whether
/// xs[i] == xs[p].
-static std::pair<Value, Value>
-createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
- ValueRange xs, Value i, Value p, AffineMap xPerm, uint64_t ny,
- bool isCoo, int step) {
+static std::pair<Value, Value> createScanLoop(OpBuilder &builder,
+ ModuleOp module,
+ func::FuncOp func, ValueRange xs,
+ Value i, Value p, AffineMap xPerm,
+ uint64_t ny, int step) {
Location loc = func.getLoc();
scf::WhileOp whileOp =
builder.create<scf::WhileOp>(loc, TypeRange{i.getType()}, ValueRange{i});
@@ -427,8 +408,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands.push_back(before->getArgument(0));
}
compareOperands.append(xs.begin(), xs.end());
- Value cond =
- createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+ Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
Block *after =
@@ -443,7 +423,7 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
compareOperands[0] = i;
compareOperands[1] = p;
Value compareEq =
- createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny, isCoo);
+ createInlinedEqCompare(builder, loc, compareOperands, xPerm, ny);
return std::make_pair(whileOp.getResult(0), compareEq);
}
@@ -452,67 +432,63 @@ createScanLoop(OpBuilder &builder, ModuleOp module, func::FuncOp func,
/// if compareFunc(data[b], data[a]) returns true. The new insertion point is
/// right after the swap instructions.
static scf::IfOp createCompareThenSwap(OpBuilder &builder, Location loc,
- AffineMap xPerm, uint64_t ny, bool isCoo,
+ AffineMap xPerm, uint64_t ny,
SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands,
Value a, Value b) {
// Compare(data[b], data[a]).
compareOperands[0] = b;
compareOperands[1] = a;
- Value cond =
- createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+ Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
scf::IfOp ifOp = builder.create<scf::IfOp>(loc, cond, /*else=*/false);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
swapOperands[0] = b;
swapOperands[1] = a;
- createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny);
return ifOp;
}
/// Creates code to insert the 3rd element to a list of two sorted elements.
static void createInsert3rd(OpBuilder &builder, Location loc, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- SmallVectorImpl<Value> &swapOperands,
+ uint64_t ny, SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2) {
- scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
- swapOperands, compareOperands, v1, v2);
- createCompareThenSwap(builder, loc, xPerm, ny, isCoo, swapOperands,
- compareOperands, v0, v1);
+ scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+ compareOperands, v1, v2);
+ createCompareThenSwap(builder, loc, xPerm, ny, swapOperands, compareOperands,
+ v0, v1);
builder.setInsertionPointAfter(ifOp);
}
/// Creates code to sort 3 elements.
static void createSort3(OpBuilder &builder, Location loc, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- SmallVectorImpl<Value> &swapOperands,
+ uint64_t ny, SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2) {
// Sort the first 2 elements.
- scf::IfOp ifOp1 = createCompareThenSwap(
- builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0, v1);
+ scf::IfOp ifOp1 = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+ compareOperands, v0, v1);
builder.setInsertionPointAfter(ifOp1);
// Insert the 3th element.
- createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands,
- v0, v1, v2);
+ createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
+ v1, v2);
}
/// Creates code to sort 5 elements.
static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- SmallVectorImpl<Value> &swapOperands,
+ uint64_t ny, SmallVectorImpl<Value> &swapOperands,
SmallVectorImpl<Value> &compareOperands, Value v0,
Value v1, Value v2, Value v3, Value v4) {
// Sort the first 3 elements.
- createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v0,
- v1, v2);
+ createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, v0, v1,
+ v2);
auto insert4th = [&]() {
scf::IfOp ifOp = createCompareThenSwap(
- builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, v2, v3);
- createInsert3rd(builder, loc, xPerm, ny, isCoo, swapOperands,
- compareOperands, v0, v1, v2);
+ builder, loc, xPerm, ny, swapOperands, compareOperands, v2, v3);
+ createInsert3rd(builder, loc, xPerm, ny, swapOperands, compareOperands, v0,
+ v1, v2);
builder.setInsertionPointAfter(ifOp);
};
@@ -520,8 +496,8 @@ static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
insert4th();
// Insert the 5th element.
- scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, isCoo,
- swapOperands, compareOperands, v3, v4);
+ scf::IfOp ifOp = createCompareThenSwap(builder, loc, xPerm, ny, swapOperands,
+ compareOperands, v3, v4);
insert4th();
builder.setInsertionPointAfter(ifOp);
}
@@ -532,10 +508,9 @@ static void createSort5(OpBuilder &builder, Location loc, AffineMap xPerm,
/// include the middle of [lo, mi) and [mi, hi) and sort a total of five values.
static void createChoosePivot(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
- bool isCoo, Value lo, Value hi, Value mi,
- ValueRange args) {
+ Value lo, Value hi, Value mi, ValueRange args) {
SmallVector<Value> compareOperands{mi, lo};
- uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
+ constexpr uint64_t numXBuffers = 1;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
SmallVector<Value> swapOperands{mi, lo};
@@ -551,8 +526,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// When len < 1000, choose pivot from median of 3 values.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
- createSort3(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
- mi, hi);
+ createSort3(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, mi,
+ hi);
// When len >= 1000, choose pivot from median of 5 values.
builder.setInsertionPointToStart(&lenIf.getElseRegion().front());
@@ -563,8 +538,8 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
Value b = builder.create<arith::AddIOp>(loc, mi, hiP1);
// Value b is the middle between [mi, hi].
b = builder.create<arith::ShRUIOp>(loc, b, c1);
- createSort5(builder, loc, xPerm, ny, isCoo, swapOperands, compareOperands, lo,
- a, mi, b, hi);
+ createSort5(builder, loc, xPerm, ny, swapOperands, compareOperands, lo, a, mi,
+ b, hi);
builder.setInsertionPointAfter(lenIf);
}
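
The two branches above implement the pivot-selection policy described by the surrounding comments: for ranges shorter than 1000 elements the generated IR sorts {lo, mi, hi} (median of 3); otherwise it also pulls in the midpoints of the two halves and sorts five positions (median of 5), leaving the pivot at `mi`. As a rough mental model only, here is the same policy on a plain array of keys; `choosePivot`, `compareThenSwap`, and the direct key comparisons are illustrative stand-ins for the SCF/arith ops the pass actually emits on the permuted xs/ys tuples.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Compare-and-swap on positions a and b, mirroring createCompareThenSwap.
static void compareThenSwap(std::vector<uint64_t> &v, size_t a, size_t b) {
  if (v[b] < v[a])
    std::swap(v[a], v[b]);
}

// Median-of-3 for short ranges, median-of-5 for ranges of >= 1000 elements;
// the chosen pivot ends up at position mi, as in createChoosePivot.
static void choosePivot(std::vector<uint64_t> &v, size_t lo, size_t mi,
                        size_t hi) {
  if (hi - lo < 1000) {
    compareThenSwap(v, lo, mi); // sort {lo, mi, hi}
    compareThenSwap(v, mi, hi);
    compareThenSwap(v, lo, mi);
  } else {
    size_t a = lo + (mi - lo) / 2;      // middle of [lo, mi)
    size_t b = mi + (hi - mi + 1) / 2;  // middle of [mi, hi]
    size_t idx[5] = {lo, a, mi, b, hi}; // sort these five positions
    for (int i = 1; i < 5; ++i)
      for (int j = i; j > 0 && v[idx[j]] < v[idx[j - 1]]; --j)
        std::swap(v[idx[j]], v[idx[j - 1]]);
  }
}

int main() {
  std::vector<uint64_t> v = {9, 7, 5, 3, 1};
  choosePivot(v, 0, 2, 4);  // short range: median of {v[0], v[2], v[4]}
  return v[2] == 5 ? 0 : 1; // pivot value 5 now sits at position 2
}
```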
@@ -601,7 +576,7 @@ static void createChoosePivot(OpBuilder &builder, ModuleOp module,
// }
static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
- bool isCoo, uint32_t nTrailingP = 0) {
+ uint32_t nTrailingP = 0) {
// Quick sort partition doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -620,7 +595,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
Value i = lo;
Value j = builder.create<arith::SubIOp>(loc, hi, c1);
- createChoosePivot(builder, module, func, xPerm, ny, isCoo, i, j, p, args);
+ createChoosePivot(builder, module, func, xPerm, ny, i, j, p, args);
Value trueVal = constantI1(builder, loc, true); // The value for while (true)
SmallVector<Value, 4> operands{i, j, p, trueVal}; // Exactly four values.
SmallVector<Type, 4> types{i.getType(), j.getType(), p.getType(),
@@ -642,14 +617,14 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
j = after->getArgument(1);
p = after->getArgument(2);
- uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
+ constexpr uint64_t numXBuffers = 1;
auto [iresult, iCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
- i, p, xPerm, ny, isCoo, 1);
+ i, p, xPerm, ny, 1);
i = iresult;
auto [jresult, jCompareEq] =
createScanLoop(builder, module, func, args.slice(xStartIdx, numXBuffers),
- j, p, xPerm, ny, isCoo, -1);
+ j, p, xPerm, ny, -1);
j = jresult;
// If i < j:
@@ -659,7 +634,7 @@ static void createPartitionFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
SmallVector<Value> swapOperands{i, j};
swapOperands.append(args.begin() + xStartIdx, args.end());
- createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny);
// If the pivot is moved, update p with the new pivot.
Value icond =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, i, p);
@@ -752,7 +727,7 @@ static Value createSubTwoDividedByTwo(OpBuilder &builder, Location loc,
//
static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
- bool isCoo, uint32_t nTrailingP) {
+ uint32_t nTrailingP) {
// The value n is passed in as a trailing parameter.
assert(nTrailingP == 1);
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -782,7 +757,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointToStart(&ifNc.getThenRegion().front());
Value c1 = constantIndex(builder, loc, 1);
SmallVector<Value> compareOperands{start, start};
- uint64_t numXBuffers = isCoo ? 1 : xPerm.getNumResults();
+ constexpr uint64_t numXBuffers = 1;
compareOperands.append(args.begin() + xStartIdx,
args.begin() + xStartIdx + numXBuffers);
@@ -808,7 +783,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
compareOperands[0] = lChildIdx;
compareOperands[1] = rChildIdx;
Value cond2 =
- createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+ createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
scf::IfOp if2 =
builder.create<scf::IfOp>(loc, ifTypes, cond2, /*else=*/true);
builder.setInsertionPointToStart(&if2.getThenRegion().front());
@@ -839,8 +814,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
childIdx = before->getArgument(2);
compareOperands[0] = start;
compareOperands[1] = childIdx;
- Value cond =
- createInlinedLessThan(builder, loc, compareOperands, xPerm, ny, isCoo);
+ Value cond = createInlinedLessThan(builder, loc, compareOperands, xPerm, ny);
builder.create<scf::ConditionOp>(loc, cond, before->getArguments());
// The after-region of the WhileOp.
@@ -850,7 +824,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
childIdx = after->getArgument(2);
SmallVector<Value> swapOperands{start, childIdx};
swapOperands.append(args.begin() + xStartIdx, args.end());
- createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny);
start = childIdx;
Value cond2 =
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, t, child);
@@ -884,7 +858,7 @@ static void createShiftDownFunc(OpBuilder &builder, ModuleOp module,
// }
static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
- bool isCoo, uint32_t nTrailingP) {
+ uint32_t nTrailingP) {
// Heap sort function doesn't have trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -911,7 +885,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
shiftDownOperands.append(args.begin() + xStartIdx, args.end());
shiftDownOperands.push_back(n);
FlatSymbolRefAttr shiftDownFunc = getMangledSortHelperFunc(
- builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny, isCoo,
+ builder, func, TypeRange(), kShiftDownFuncNamePrefix, xPerm, ny,
shiftDownOperands, createShiftDownFunc, /*nTrailingP=*/1);
builder.create<func::CallOp>(loc, shiftDownFunc, TypeRange(),
shiftDownOperands);
@@ -926,7 +900,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
loplm1 = builder.create<arith::SubIOp>(loc, loplm1, c1);
SmallVector<Value> swapOperands{lo, loplm1};
swapOperands.append(args.begin() + xStartIdx, args.end());
- createSwap(builder, loc, swapOperands, xPerm, ny, isCoo);
+ createSwap(builder, loc, swapOperands, xPerm, ny);
shiftDownOperands[1] = lo;
shiftDownOperands[shiftDownOperands.size() - 1] =
builder.create<arith::SubIOp>(loc, l, c1);
@@ -942,7 +916,7 @@ static void createHeapSortFunc(OpBuilder &builder, ModuleOp module,
/// the bigger partition to be processed by the enclosed while-loop.
static std::pair<Value, Value>
createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
- ValueRange args, AffineMap xPerm, uint64_t ny, bool isCoo,
+ ValueRange args, AffineMap xPerm, uint64_t ny,
uint32_t nTrailingP) {
MLIRContext *context = module.getContext();
Location loc = func.getLoc();
@@ -952,7 +926,7 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm,
- ny, isCoo, args.drop_back(nTrailingP), createPartitionFunc);
+ ny, args.drop_back(nTrailingP), createPartitionFunc);
Value p = builder
.create<func::CallOp>(loc, partitionFunc,
TypeRange{IndexType::get(context)},
@@ -1023,7 +997,7 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func,
// }
static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm,
- uint64_t ny, bool isCoo, uint32_t nTrailingP) {
+ uint64_t ny, uint32_t nTrailingP) {
// Stable sort function doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -1049,7 +1023,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
operands.append(args.begin() + xStartIdx, args.end());
FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc(
builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix,
- xPerm, ny, isCoo, operands, createBinarySearchFunc);
+ xPerm, ny, operands, createBinarySearchFunc);
Value p = builder
.create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()},
operands)
@@ -1059,7 +1033,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
operands[0] = operands[1] = i;
SmallVector<Value> d;
forEachIJPairInAllBuffers(
- builder, loc, operands, xPerm, ny, isCoo,
+ builder, loc, operands, xPerm, ny,
[&](uint64_t unused, Value i, Value unused2, Value buffer) {
d.push_back(builder.create<memref::LoadOp>(loc, buffer, i));
});
@@ -1075,7 +1049,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
operands[1] = imj;
operands[0] = builder.create<arith::SubIOp>(loc, imj, c1);
forEachIJPairInAllBuffers(
- builder, loc, operands, xPerm, ny, isCoo,
+ builder, loc, operands, xPerm, ny,
[&](uint64_t unused, Value imjm1, Value imj, Value buffer) {
Value t = builder.create<memref::LoadOp>(loc, buffer, imjm1);
builder.create<memref::StoreOp>(loc, t, buffer, imj);
@@ -1085,7 +1059,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
builder.setInsertionPointAfter(forOpJ);
operands[0] = operands[1] = p;
forEachIJPairInAllBuffers(
- builder, loc, operands, xPerm, ny, isCoo,
+ builder, loc, operands, xPerm, ny,
      [&](uint64_t k, Value p, Value unused, Value buffer) {
builder.create<memref::StoreOp>(loc, d[k], buffer, p);
});
@@ -1138,7 +1112,7 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module,
//
static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
func::FuncOp func, AffineMap xPerm, uint64_t ny,
- bool isCoo, uint32_t nTrailingP) {
+ uint32_t nTrailingP) {
assert(nTrailingP == 1 || nTrailingP == 0);
bool isHybrid = (nTrailingP == 1);
OpBuilder::InsertionGuard insertionGuard(builder);
@@ -1187,7 +1161,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When len <= limit.
builder.setInsertionPointToStart(&lenIf.getThenRegion().front());
FlatSymbolRefAttr insertionSortFunc = getMangledSortHelperFunc(
- builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny, isCoo,
+ builder, func, TypeRange(), kSortStableFuncNamePrefix, xPerm, ny,
ValueRange(args).drop_back(nTrailingP), createSortStableFunc);
builder.create<func::CallOp>(loc, insertionSortFunc, TypeRange(),
ValueRange(args).drop_back(nTrailingP));
@@ -1207,7 +1181,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When depth exceeds limit.
builder.setInsertionPointToStart(&depthIf.getThenRegion().front());
FlatSymbolRefAttr heapSortFunc = getMangledSortHelperFunc(
- builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny, isCoo,
+ builder, func, TypeRange(), kHeapSortFuncNamePrefix, xPerm, ny,
ValueRange(args).drop_back(nTrailingP), createHeapSortFunc);
builder.create<func::CallOp>(loc, heapSortFunc, TypeRange(),
ValueRange(args).drop_back(nTrailingP));
@@ -1216,8 +1190,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
// When depth doesn't exceed limit.
builder.setInsertionPointToStart(&depthIf.getElseRegion().front());
args.back() = depthLimit;
- std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
- isCoo, nTrailingP);
+ std::tie(lo, hi) =
+ createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
builder.create<scf::YieldOp>(loc, ValueRange{lo, hi});
builder.setInsertionPointAfter(depthIf);
@@ -1229,8 +1203,8 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
lo = lenIf.getResult(0);
hi = lenIf.getResult(1);
} else {
- std::tie(lo, hi) = createQuickSort(builder, module, func, args, xPerm, ny,
- isCoo, nTrailingP);
+ std::tie(lo, hi) =
+ createQuickSort(builder, module, func, args, xPerm, ny, nTrailingP);
}
// New [lo, hi) for the next while-loop iteration.
@@ -1244,8 +1218,7 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module,
/// Implements the rewriting for operator sort and sort_coo.
template <typename OpTy>
LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
- uint64_t ny, bool isCoo,
- PatternRewriter &rewriter) {
+ uint64_t ny, PatternRewriter &rewriter) {
Location loc = op.getLoc();
SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()};
@@ -1298,9 +1271,9 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
break;
}
- FlatSymbolRefAttr func = getMangledSortHelperFunc(
- rewriter, insertPoint, TypeRange(), funcName, xPerm, ny, isCoo, operands,
- funcGenerator, nTrailingP);
+ FlatSymbolRefAttr func =
+ getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName,
+ xPerm, ny, operands, funcGenerator, nTrailingP);
rewriter.replaceOpWithNewOp<func::CallOp>(op, func, TypeRange(), operands);
return success();
}
@@ -1310,7 +1283,6 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm,
//===---------------------------------------------------------------------===//
namespace {
-
/// Sparse rewriting rule for the push_back operator.
struct PushBackRewriter : OpRewritePattern<PushBackOp> {
public:
@@ -1440,8 +1412,7 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
if (auto nyAttr = op.getNyAttr())
ny = nyAttr.getInt();
- return matchAndRewriteSortOp(op, xys, xPerm, ny,
- /*isCoo=*/true, rewriter);
+ return matchAndRewriteSortOp(op, xys, xPerm, ny, rewriter);
}
};
>From b21bc7ed97444ff95f8607895b14bb1198db3ad3 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Fri, 15 Sep 2023 23:13:46 +0000
Subject: [PATCH 3/4] update test cases.
---
.../Transforms/SparseBufferRewriting.cpp | 14 +-
.../SparseTensor/CPU/sparse_rewrite_sort.mlir | 187 ------------------
.../CPU/sparse_rewrite_sort_coo.mlir | 8 +-
3 files changed, 11 insertions(+), 198 deletions(-)
delete mode 100644 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 7011578a5afa0b7..101bd165cc598b2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -110,14 +110,12 @@ static void forEachIJPairInXs(
Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
- for (AffineExpr e : xPerm.getResults()) {
- unsigned k = e.cast<AffineDimExpr>().getPosition();
- scf::IfOp ifOp;
- Value i, j, buffer;
- Value ck = constantIndex(builder, loc, k);
- i = builder.create<arith::AddIOp>(loc, ck, iOffset);
- j = builder.create<arith::AddIOp>(loc, ck, jOffset);
- buffer = args[xStartIdx];
+ for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
+ unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
+ Value ak = constantIndex(builder, loc, actualK);
+ Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
+ Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
+ Value buffer = args[xStartIdx];
bodyBuilder(k, i, j, buffer);
}
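
After this change every record in the packed `xy` buffer occupies `xPerm.getNumResults() + ny` consecutive slots, and the k-th comparison key of record `r` is read at `r * cstep + perm_map(k)`. Below is a small, self-contained restatement of that address arithmetic; the `keyIndex` helper and the concrete numbers are illustrative, not part of the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Position of the k-th sort key of record `rec` inside the packed xy buffer,
// given the permutation as a list of dimension positions and the number of
// trailing y values stored per record.
static uint64_t keyIndex(uint64_t rec, uint64_t k,
                         const std::vector<uint64_t> &permPositions,
                         uint64_t ny) {
  uint64_t step = permPositions.size() + ny; // cstep: slots per record
  return rec * step + permPositions[k];      // record offset + permuted position
}

int main() {
  // perm_map = (d0, d1, d2) -> (d1, d2, d0), ny = 1: record 4 occupies slots
  // [16, 20) and its first comparison key (d1) lives at slot 17.
  std::vector<uint64_t> perm = {1, 2, 0};
  assert(keyIndex(4, 0, perm, 1) == 17);
  return 0;
}
```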
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
deleted file mode 100644
index 9e8ecad9cf282a2..000000000000000
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ /dev/null
@@ -1,187 +0,0 @@
-//--------------------------------------------------------------------------------------------------
-// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
-//
-// Set-up that's shared across all tests in this directory. In principle, this
-// config could be moved to lit.local.cfg. However, there are downstream users that
-// do not use these LIT config files. Hence why this is kept inline.
-//
-// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
-// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
-// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
-// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
-// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
-// DEFINE: %{run_opts} = -e entry -entry-point-result=void
-// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
-// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
-//
-// DEFINE: %{env} =
-//--------------------------------------------------------------------------------------------------
-
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with vectorization.
-// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
-// RUN: %{compile} | %{run} | FileCheck %s
-//
-// Do the same run, but now with VLA vectorization.
-// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-
-module {
- func.func private @printMemref1dI32(%ptr : memref<?xi32>) attributes { llvm.emit_c_interface }
-
- // Stores 5 values to the memref buffer.
- func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
- %v3: i32, %v4: i32) -> () {
- %i0 = arith.constant 0 : index
- %i1 = arith.constant 1 : index
- %i2 = arith.constant 2 : index
- %i3 = arith.constant 3 : index
- %i4 = arith.constant 4 : index
- memref.store %v0, %b[%i0] : memref<?xi32>
- memref.store %v1, %b[%i1] : memref<?xi32>
- memref.store %v2, %b[%i2] : memref<?xi32>
- memref.store %v3, %b[%i3] : memref<?xi32>
- memref.store %v4, %b[%i4] : memref<?xi32>
- return
- }
-
- // The main driver.
- func.func @entry() {
- %c0 = arith.constant 0 : i32
- %c1 = arith.constant 1 : i32
- %c2 = arith.constant 2 : i32
- %c3 = arith.constant 3 : i32
- %c4 = arith.constant 4 : i32
- %c5 = arith.constant 5 : i32
- %c6 = arith.constant 6 : i32
- %c7 = arith.constant 7 : i32
- %c8 = arith.constant 8 : i32
- %c9 = arith.constant 9 : i32
- %c10 = arith.constant 10 : i32
- %c100 = arith.constant 100 : i32
-
- %i0 = arith.constant 0 : index
- %i4 = arith.constant 4 : index
- %i5 = arith.constant 5 : index
-
- // Prepare a buffer.
- %x0s = memref.alloc() : memref<5xi32>
- %x0 = memref.cast %x0s : memref<5xi32> to memref<?xi32>
- call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
-
- // Sort 0 elements.
- // Quick sort.
- // CHECK: [10, 2, 0, 5, 1]
- sparse_tensor.sort quick_sort %i0, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- // Stable sort.
- // CHECK: [10, 2, 0, 5, 1]
- sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- // Heap sort.
- // CHECK: [10, 2, 0, 5, 1]
- sparse_tensor.sort heap_sort %i0, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- // Hybrid sort.
- // CHECK: [10, 2, 0, 5, 1]
- sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-
- // Sort the first 4 elements, with the last valid value untouched.
- // Quick sort.
- // CHECK: [0, 2, 5, 10, 1]
- sparse_tensor.sort quick_sort %i4, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- // Stable sort.
- // CHECK: [0, 2, 5, 10, 1]
- call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- // Heap sort.
- // CHECK: [0, 2, 5, 10, 1]
- call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort heap_sort %i4, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- // Hybrid sort.
- // CHECK: [0, 2, 5, 10, 1]
- sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
-
- // Prepare more buffers of different dimensions.
- %x1s = memref.alloc() : memref<10xi32>
- %x1 = memref.cast %x1s : memref<10xi32> to memref<?xi32>
- %x2s = memref.alloc() : memref<6xi32>
- %x2 = memref.cast %x2s : memref<6xi32> to memref<?xi32>
- %y0s = memref.alloc() : memref<7xi32>
- %y0 = memref.cast %y0s : memref<7xi32> to memref<?xi32>
-
- // Sort "parallel arrays".
- // CHECK: [1, 1, 2, 5, 10]
- // CHECK: [3, 3, 1, 10, 1
- // CHECK: [9, 9, 4, 7, 2
- // CHECK: [7, 8, 10, 9, 6
- call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort quick_sort %i5, %x0, %x1, %x2 jointly %y0
- : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
- // Stable sort.
- // CHECK: [1, 1, 2, 5, 10]
- // CHECK: [3, 3, 1, 10, 1
- // CHECK: [9, 9, 4, 7, 2
- // CHECK: [8, 7, 10, 9, 6
- call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
- : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
- // Heap sort.
- // CHECK: [1, 1, 2, 5, 10]
- // CHECK: [3, 3, 1, 10, 1
- // CHECK: [9, 9, 4, 7, 2
- // CHECK: [7, 8, 10, 9, 6
- call @storeValuesTo(%x0, %c10, %c2, %c1, %c5, %c1)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%x1, %c1, %c1, %c3, %c10, %c3)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%x2, %c2, %c4, %c9, %c7, %c9)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
- : (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort heap_sort %i5, %x0, %x1, %x2 jointly %y0
- : memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
- call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%x2) : (memref<?xi32>) -> ()
- call @printMemref1dI32(%y0) : (memref<?xi32>) -> ()
-
- // Release the buffers.
- memref.dealloc %x0 : memref<?xi32>
- memref.dealloc %x1 : memref<?xi32>
- memref.dealloc %x2 : memref<?xi32>
- memref.dealloc %y0 : memref<?xi32>
- return
- }
-}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index ca5dd00d02aff1e..0594b311184f4d5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -28,6 +28,8 @@
// Do the same run, but now with VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+#ID_MAP = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
module {
// Stores 5 values to the memref buffer.
func.func @storeValuesTo(%b: memref<?xi32>, %v0: i32, %v1: i32, %v2: i32,
@@ -109,7 +111,7 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v : vector<5xi32>
@@ -137,7 +139,7 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v2 : vector<5xi32>
@@ -165,7 +167,7 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v3 : vector<5xi32>
>From bb146d5ead3f29e832b0bf3d3c94046089727052 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Tue, 19 Sep 2023 00:06:40 +0000
Subject: [PATCH 4/4] fix check tests
---
.../SparseTensor/IR/SparseTensorOps.td | 24 ++--
.../SparseTensor/IR/SparseTensorDialect.cpp | 17 +--
.../Transforms/SparseBufferRewriting.cpp | 9 +-
.../SparseTensor/buffer_rewriting.mlir | 105 ++++--------------
mlir/test/Dialect/SparseTensor/codegen.mlir | 6 +-
.../SparseTensor/convert_sparse2sparse.mlir | 2 +-
mlir/test/Dialect/SparseTensor/invalid.mlir | 47 ++++----
mlir/test/Dialect/SparseTensor/roundtrip.mlir | 64 ++---------
.../SparseTensor/sparse_matmul_codegen.mlir | 2 +-
.../CPU/sparse_rewrite_sort_coo.mlir | 51 ++++-----
10 files changed, 109 insertions(+), 218 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index d83d1ba03feb848..59815fc755ee5f3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -765,27 +765,29 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
- AffineMapAttr:$nx, OptionalAttr<IndexAttr>:$ny,
+ AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
SparseTensorSortKindAttr:$algorithm)> {
let summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
let description = [{
- Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
- `xs` values and some `ys` values are put in the linear buffer `xy`. The
- optional index attribute `nx` provides the number of `xs` values in `xy`.
- When `nx` is not explicitly specified, its value is 1. The optional index
- attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
- explicitly specified, its value is 0. This instruction supports a more
- efficient way to store the COO definition in sparse tensor type.
-
- The buffer xy should have a dimension not less than n * (nx + ny) while the
+      Sparse_tensor.sort_coo sorts the `xs` values along with some `ys` values
+      that are put in a single linear buffer `xy`.
+      The affine map attribute `perm_map` specifies the permutation to be applied to
+      the `xs` before comparison; the rank of the permutation map
+      also specifies the number of `xs` values in `xy`.
+      The optional index attribute `ny` provides the number of `ys` values in `xy`.
+      When `ny` is not explicitly specified, its value is 0.
+      This instruction supports a more efficient way to store the COO definition
+      in a sparse tensor type.
+
+ The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
buffers in `ys` should have a dimension not less than `n`. The behavior of
the operator is undefined if this condition is not met.
Example:
```mlir
- sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+ sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
: memref<?xindex>
```
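
To make the prose above concrete, the following is a minimal reference model of the intended semantics in plain C++, under the stated layout: records of `rank(perm_map) + ny` slots in `xy`, keys compared in `perm_map` order, and any extra `ys` buffers reordered alongside. The helper name `sortCooReference` and the use of a stable sort are illustrative assumptions; the operation itself only promises stability for `insertion_sort_stable`.

```cpp
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Reference model (assumed semantics, not the emitted code): sort the first n
// records of xy lexicographically on the keys selected by `perm`, and reorder
// every y buffer with the same permutation of records.
static void sortCooReference(uint64_t n, std::vector<uint64_t> &xy,
                             const std::vector<uint64_t> &perm, uint64_t ny,
                             std::vector<std::vector<float>> &ys) {
  uint64_t step = perm.size() + ny; // slots per record in xy
  std::vector<uint64_t> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::stable_sort(order.begin(), order.end(), [&](uint64_t a, uint64_t b) {
    for (uint64_t k = 0; k < perm.size(); ++k) {
      uint64_t ka = xy[a * step + perm[k]], kb = xy[b * step + perm[k]];
      if (ka != kb)
        return ka < kb;
    }
    return false;
  });
  // Apply the record permutation to xy and to every y buffer.
  std::vector<uint64_t> newXy(xy.begin(), xy.begin() + n * step);
  for (uint64_t r = 0; r < n; ++r)
    std::copy_n(xy.begin() + order[r] * step, step, newXy.begin() + r * step);
  std::copy(newXy.begin(), newXy.end(), xy.begin());
  for (auto &y : ys) {
    std::vector<float> newY(y.begin(), y.begin() + n);
    for (uint64_t r = 0; r < n; ++r)
      newY[r] = y[order[r]];
    std::copy(newY.begin(), newY.end(), y.begin());
  }
}

int main() {
  // Three records of two keys plus one value kept inside xy (ny = 1), identity perm.
  std::vector<uint64_t> xy = {3, 1, 100, 1, 2, 200, 1, 1, 300};
  std::vector<std::vector<float>> ys = {{30.f, 20.f, 10.f}};
  sortCooReference(3, xy, {0, 1}, 1, ys);
  // xy is now (1,1,300), (1,2,200), (3,1,100); ys[0] is (10, 20, 30).
  return 0;
}
```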
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3cd0847bdf73765..9675a61109477b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1354,6 +1354,14 @@ LogicalResult SelectOp::verify() {
}
LogicalResult SortCooOp::verify() {
+ AffineMap xPerm = getPermMap();
+ uint64_t nx = xPerm.getNumDims();
+ if (nx < 1)
+    emitError(llvm::formatv("Expected rank(perm_map) >= 1, got {0}", nx));
+
+ if (!xPerm.isPermutation())
+ emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
@@ -1361,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
return success();
uint64_t n = cn.value();
- uint64_t nx = 1;
- if (auto nxAttr = getNxAttr()) {
- nx = nxAttr.getAffineMap().getNumResults();
- if (nx < 1)
- emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
- }
uint64_t ny = 0;
if (auto nyAttr = getNyAttr()) {
ny = nyAttr.getInt();
@@ -1381,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
};
- checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+ checkDim(getXy(), n * (nx + ny),
+ "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
for (Value opnd : getYs()) {
checkDim(opnd, n, "Expected dimension(y) >= n");
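
The size rule the verifier enforces is easy to restate numerically: `xy` must hold at least `n * (rank(perm_map) + ny)` elements and every `ys` buffer at least `n`. A tiny illustrative check follows, using the same numbers as the updated `invalid.mlir` case further down; the `minXySize` helper is not part of the patch.

```cpp
#include <cassert>
#include <cstdint>

// Minimum xy size for n records with rank(perm_map) keys and ny trailing
// values per record (illustrative restatement of the verifier's check).
static uint64_t minXySize(uint64_t n, uint64_t rank, uint64_t ny) {
  return n * (rank + ny);
}

int main() {
  // n = 20, rank(perm_map) = 2, ny = 1 requires 60 slots, so a
  // memref<50xindex> is rejected with "got 50 < 60".
  assert(minXySize(20, 2, 1) == 60);
  return 0;
}
```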
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 101bd165cc598b2..3181395a474cfec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -54,8 +54,11 @@ using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
StringRef namePrefix, AffineMap xPerm,
uint64_t ny, ValueRange operands) {
- nameOstream << namePrefix << xPerm << "_"
- << getMemRefType(operands[xStartIdx]).getElementType();
+ nameOstream << namePrefix;
+ for (auto res : xPerm.getResults())
+ nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
+
+ nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
nameOstream << "_coo_" << ny;
constexpr uint64_t yBufferOffset = 1;
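
This mangling change is what renames the generated helpers in the tests below: the permutation result positions are now emitted one by one instead of the old `nx` count. A sketch with plain strings is shown here; the real code streams MLIR types into a `raw_svector_ostream`, and the per-`ys` type suffixes are inferred from the updated CHECK lines rather than from this hunk.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Illustrative sketch of the helper-name mangling after this change.
static std::string mangleSortHelper(const std::string &prefix,
                                    const std::vector<unsigned> &permPositions,
                                    const std::string &xyElemType, uint64_t ny,
                                    const std::vector<std::string> &yElemTypes) {
  std::string name = prefix;
  for (unsigned pos : permPositions)
    name += std::to_string(pos) + "_";  // permutation positions, e.g. "0_1_"
  name += xyElemType;                   // element type of the xy buffer
  name += "_coo_" + std::to_string(ny); // trailing y values kept inside xy
  for (const std::string &t : yElemTypes)
    name += "_" + t;                    // one suffix per extra y buffer
  return name;
}

int main() {
  assert(mangleSortHelper("_sparse_partition_", {0, 1}, "index", 1,
                          {"f32", "i32"}) ==
         "_sparse_partition_0_1_index_coo_1_f32_i32");
  return 0;
}
```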
@@ -1405,7 +1408,7 @@ struct SortCooRewriter : public OpRewritePattern<SortCooOp> {
xys.push_back(op.getXy());
xys.append(op.getYs().begin(), op.getYs().end());
- auto xPerm = op.getNx();
+ auto xPerm = op.getPermMap();
uint64_t ny = 0;
if (auto nyAttr = op.getNyAttr())
ny = nyAttr.getInt();
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 0036bd5c3310b97..c96a55aa1e8b2f3 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -75,123 +75,64 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// -----
-// CHECK-LABEL: func.func private @_sparse_partition_1_i8_f32_index
-// CHECK-LABEL: func.func private @_sparse_qsort_1_i8_f32_index
-// CHECK-LABEL: func.func @sparse_sort_1d2v_quick
-func.func @sparse_sort_1d2v_quick(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
- -> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
- sparse_tensor.sort quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
- return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_quick
-func.func @sparse_sort_3d_quick(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
- sparse_tensor.sort quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
- return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting function now. We have integration test
-// to verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-DAG: func.func private @_sparse_partition_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: i64) {
-// CHECK-LABEL: func.func @sparse_sort_3d_hybrid
-func.func @sparse_sort_3d_hybrid(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
- sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
- return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_binary_search_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_stable
-func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
- sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
- return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
// Only check the generated supporting functions. We have an integration test to
// verify correctness of the generated code.
//
-// CHECK-DAG: func.func private @_sparse_shift_down_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
-// CHECK-LABEL: func.func @sparse_sort_3d_heap
-func.func @sparse_sort_3d_heap(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
- sparse_tensor.sort heap_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
- return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
-}
-
-// -----
-
-// Only check the generated supporting functions. We have integration test to
-// verify correctness of the generated code.
-//
-// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_quick
func.func @sparse_sort_coo_quick(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
- sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ sparse_tensor.sort_coo quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
// -----
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
// Only check the generated supporting functions. We have an integration test to
// verify correctness of the generated code.
//
-// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
-// CHECK-DAG: func.func private @_sparse_partition_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_hybrid_qsort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
+// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_partition_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_hybrid_qsort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: i64) {
// CHECK-LABEL: func.func @sparse_sort_coo_hybrid
func.func @sparse_sort_coo_hybrid(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
- sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
// -----
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
// Only check the generated supporting functions. We have an integration test to
// verify correctness of the generated code.
//
-// CHECK-DAG: func.func private @_sparse_binary_search_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
-// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_binary_search_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) -> index {
+// CHECK-DAG: func.func private @_sparse_sort_stable_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_stable
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
- sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
// -----
+#ID_MAP=affine_map<(d0, d1) -> (d0, d1)>
+
// Only check the generated supporting functions. We have an integration test to
// verify correctness of the generated code.
//
-// CHECK-DAG: func.func private @_sparse_shift_down_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
-// CHECK-DAG: func.func private @_sparse_heap_sort_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
+// CHECK-DAG: func.func private @_sparse_shift_down_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>, %arg5: index) {
+// CHECK-DAG: func.func private @_sparse_heap_sort_0_1_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_heap
func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
- sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {perm_map = #ID_MAP, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index f1317f23d656848..ea11a98b76ec639 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -436,7 +436,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -484,7 +484,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[A12:.*]] = arith.constant 1 : index
// CHECK: %[[A13:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A7]], %[[A6]]
// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
@@ -712,7 +712,7 @@ func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) ->
// CHECK: %[[A33:.*]] = call @getSparseTensorReaderReadToBuffers0F32(%[[A5]], %[[A32]], %[[A14]], %[[A15]])
// CHECK: %[[A34:.*]] = arith.cmpi eq, %[[A33]], %[[A1]] : i1
// CHECK: scf.if %[[A34]] {
-// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {nx = 2 : index, ny = 0 : index} : memref<?xindex> jointly memref<?xf32>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A10]], %[[A14]] jointly %[[A15]] {ny = 0 : index, perm_map = #{{.*}}} : memref<?xindex> jointly memref<?xf32>
// CHECK: }
// CHECK: memref.store %[[A10]], %[[A27]]{{\[}}%[[A2]]] : memref<?xindex>
// CHECK: %[[A36:.*]] = sparse_tensor.storage_specifier.set %[[A30]] crd_mem_sz at 0 with %[[A11]]
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
index b3eb50f1755dace..54cdfc690952d9a 100644
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -178,7 +178,7 @@ func.func @sparse_convert_singleton(%arg0: tensor<?xf32, #SparseSingleton64>) ->
// CHECK-RWT: %[[VAL_16:.*]] = sparse_tensor.load %[[VAL_17:.*]] hasInserts : tensor<?x?x?xf32, #{{.*}}>>
// CHECK-RWT: %[[VAL_18:.*]] = sparse_tensor.values %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xf32>
// CHECK-RWT: %[[VAL_19:.*]] = sparse_tensor.coordinates_buffer %[[VAL_16]] : tensor<?x?x?xf32, #{{.*}}>> to memref<?xindex>
-// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {nx = 3 : index, ny = 0 : index}
+// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_7]], %[[VAL_19]] jointly %[[VAL_18]] {ny = 0 : index, perm_map = #map}
// CHECK-RWT: %[[VAL_20:.*]] = bufferization.alloc_tensor(%[[VAL_4]], %[[VAL_5]], %[[VAL_6]]) size_hint=%[[VAL_7]]
// CHECK-RWT: %[[VAL_21:.*]] = sparse_tensor.foreach in %[[VAL_16]] init(%[[VAL_20]])
// CHECK-RWT: ^bb0(%[[VAL_22:.*]]: index, %[[VAL_23:.*]]: index, %[[VAL_24:.*]]: index, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: tensor<?x?x?xf32, #{{.*}}>>):
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 71e6eebb30261c8..c0e813dcde7c57e 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -790,60 +790,51 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
return
}
-// -----
-
-// TODO: a test case with empty xs doesn't work due to some parser issues.
-
-func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
-  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
- sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
-}
-
-// -----
-
-func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
- %i20 = arith.constant 20 : index
-  // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
- sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
- return
-}
// -----
-func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
-  // expected-error@+1 {{mismatch xs element types}}
- sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
- return
-}
-
-// -----
+#MAP = affine_map<(i,j) -> (i,j)>
func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
- sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
+ sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 {perm_map = #MAP} : memref<?xf32>
return
}
// -----
+#MAP = affine_map<(i,j) -> (i,j)>
+
func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
%i20 = arith.constant 20 : index
-  // expected-error@+1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
- sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  // expected-error@+1 {{Expected dimension(xy) >= n * (rank(perm_map) + ny) got 50 < 60}}
+ sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {perm_map = #MAP, ny = 1 : index} : memref<50xindex>
return
}
// -----
+#MAP = affine_map<(i,j) -> (i,j)>
+
func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
%i20 = arith.constant 20 : index
  // expected-error@+1 {{Expected dimension(y) >= n got 10 < 20}}
- sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+ sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {perm_map = #MAP, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
return
}
// -----
+#NON_PERM_MAP = affine_map<(i,j) -> (i,i)>
+
+func.func @sparse_sort_coo_no_perm(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  // expected-error@+1 {{Expected a permutation map, got (d0, d1) -> (d0, d0)}}
+ sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #NON_PERM_MAP, ny = 1 : index}: memref<?xindex>
+ return %arg1 : memref<?xindex>
+}
+
+// -----
+
#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index d1262cb7aea02df..d252fa559a1543f 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -612,79 +612,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
// -----
-// CHECK-LABEL: func @sparse_sort_1d0v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
-// CHECK: return %[[B]]
-func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
- sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
- return %arg1 : memref<?xindex>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_1d2v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<20xindex>,
-// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
-// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
- sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
- return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_2d1v(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
- sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
- return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
-
-// CHECK-LABEL: func @sparse_sort_stable(
-// CHECK-SAME: %[[A:.*]]: index,
-// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
-// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
-// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
-// CHECK: return %[[B]], %[[C]], %[[D]]
-func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
- sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
- return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
-}
-
-// -----
+#ID_MAP = affine_map<(i,j) -> (i,j)>
// CHECK-LABEL: func @sparse_sort_coo(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {ny = 1 : index, perm_map = #{{.*}}} : memref<?xindex>
// CHECK: return %[[B]]
func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
- sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
+ sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xindex>
return %arg1 : memref<?xindex>
}
// -----
+#ID_MAP = affine_map<(i,j) -> (i,j)>
+
// CHECK-LABEL: func @sparse_sort_coo_stable(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xi64>,
// CHECK-SAME: %[[C:.*]]: memref<?xf32>)
-// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
+// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {ny = 1 : index, perm_map = #{{.*}}}
// CHECK: return %[[B]], %[[C]]
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
- sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
+ sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
}
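For readers skimming the round-trip test diff above: after this change the former `sparse_tensor.sort` use cases are written with `sparse_tensor.sort_coo`, where the old `nx = 2` attribute becomes a two-dimensional affine `perm_map` and the trailing `ny` count is kept. A minimal sketch of the unified form, modeled on the updated tests above (the function name and this particular mix of operand types are illustrative only):

```mlir
#ID_MAP = affine_map<(i, j) -> (i, j)>

// Sort the first %n composite keys stored in %coo (two index coordinates per
// key, plus one trailing index carried along per ny = 1), permuting the
// values in %vals jointly with the keys.
func.func @unified_sort(%n: index, %coo: memref<?xindex>, %vals: memref<?xf32>) {
  sparse_tensor.sort_coo hybrid_quick_sort %n, %coo jointly %vals
    {perm_map = #ID_MAP, ny = 1 : index}
    : memref<?xindex> jointly memref<?xf32>
  return
}
```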
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index b31ac3ef3a254ad..5c308dc3c56234b 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -116,7 +116,7 @@
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_64:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]] : memref<?xindex>
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[VAL_65:.*]], %[[VAL_33]]
// CHECK: %[[VAL_66:.*]]:4 = scf.for %[[VAL_67:.*]] = %[[VAL_10]] to %[[VAL_65]] step %[[VAL_11]] iter_args(%[[VAL_68:.*]] = %[[VAL_36]], %[[VAL_69:.*]] = %[[VAL_37]], %[[VAL_70:.*]] = %[[VAL_38]], %[[VAL_71:.*]] = %[[VAL_39]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[VAL_72:.*]] = memref.load %[[VAL_32]]{{\[}}%[[VAL_67]]] : memref<4xindex>
// CHECK: %[[VAL_73:.*]] = memref.load %[[VAL_30]]{{\[}}%[[VAL_72]]] : memref<4xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 0594b311184f4d5..394b9a8448b5438 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -28,7 +28,7 @@
// Do the same run, but now with VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-#ID_MAP = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
module {
// Stores 5 values to the memref buffer.
@@ -96,11 +96,11 @@ module {
%y1 = memref.cast %y1s : memref<7xi32> to memref<?xi32>
// Sort "parallel arrays".
- // CHECK: ( 1, 1, 3, 3, 10 )
- // CHECK: ( 2, 10, 1, 1, 5 )
- // CHECK: ( 4, 2, 9, 9, 7 )
- // CHECK: ( 10, 6, 7, 8, 9 )
- // CHECK: ( 7, 5, 7, 4, 9 )
+ // CHECK: ( 1, 1, 2, 5, 10 )
+ // CHECK: ( 9, 9, 4, 7, 2 )
+ // CHECK: ( 3, 3, 1, 10, 1 )
+ // CHECK: ( 7, 8, 10, 9, 6 )
+ // CHECK: ( 7, 4, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -111,24 +111,25 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+ sparse_tensor.sort_coo quick_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
- %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
- vector.print %x0v : vector<5xi32>
+ // Dump the buffers in the same order as the perm_map so that the printed output is ordered.
%x1v = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v : vector<5xi32>
%x2v = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v : vector<5xi32>
+ %x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x0v : vector<5xi32>
%y0v = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v : vector<5xi32>
%y1v = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
vector.print %y1v : vector<5xi32>
// Stable sort.
- // CHECK: ( 1, 1, 3, 3, 10 )
- // CHECK: ( 2, 10, 1, 1, 5 )
- // CHECK: ( 4, 2, 9, 9, 7 )
- // CHECK: ( 10, 6, 8, 7, 9 )
- // CHECK: ( 7, 5, 4, 7, 9 )
+ // CHECK: ( 1, 1, 2, 5, 10 )
+ // CHECK: ( 9, 9, 4, 7, 2 )
+ // CHECK: ( 3, 3, 1, 10, 1 )
+ // CHECK: ( 8, 7, 10, 9, 6 )
+ // CHECK: ( 4, 7, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -139,24 +140,24 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+ sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
- %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
- vector.print %x0v2 : vector<5xi32>
%x1v2 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v2 : vector<5xi32>
%x2v2 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v2 : vector<5xi32>
+ %x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x0v2 : vector<5xi32>
%y0v2 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v2 : vector<5xi32>
%y1v2 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
vector.print %y1v2 : vector<5xi32>
// Heap sort.
- // CHECK: ( 1, 1, 3, 3, 10 )
- // CHECK: ( 2, 10, 1, 1, 5 )
- // CHECK: ( 4, 2, 9, 9, 7 )
- // CHECK: ( 10, 6, 8, 7, 9 )
- // CHECK: ( 7, 5, 4, 7, 9 )
+ // CHECK: ( 1, 1, 2, 5, 10 )
+ // CHECK: ( 9, 9, 4, 7, 2 )
+ // CHECK: ( 3, 3, 1, 10, 1 )
+ // CHECK: ( 7, 8, 10, 9, 6 )
+ // CHECK: ( 7, 4, 7, 9, 5 )
call @storeValuesToStrided(%x0, %c1, %c1, %c3, %c10, %c3)
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesToStrided(%x1, %c10, %c2, %c1, %c5, %c1)
@@ -167,14 +168,14 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {nx = #ID_MAP, ny = 1 : index}
+ sparse_tensor.sort_coo heap_sort %i5, %xy jointly %y1 {perm_map = #ID_MAP, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
- %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
- vector.print %x0v3 : vector<5xi32>
%x1v3 = vector.transfer_read %x1[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x1v3 : vector<5xi32>
%x2v3 = vector.transfer_read %x2[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x2v3 : vector<5xi32>
+ %x0v3 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
+ vector.print %x0v3 : vector<5xi32>
%y0v3 = vector.transfer_read %y0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %y0v3 : vector<5xi32>
%y1v3 = vector.transfer_read %y1[%i0], %c100: memref<?xi32>, vector<5xi32>
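A note on the updated integration test expectations: with `#ID_MAP = affine_map<(d0, d1, d2) -> (d1, d2, d0)>` the tuples are now compared on the permuted key (x1, x2, x0) instead of (x0, x1, x2), which is why both the CHECK vectors and the order of the `vector.transfer_read`/`vector.print` dumps changed (the buffers are printed in key order so the output reads as sorted). For example, the five input tuples (x0, x1, x2) = (1, 10, 2), (1, 2, 4), (3, 1, 9), (10, 5, 7), (3, 1, 9) sort under the permuted key as (3, 1, 9), (3, 1, 9), (1, 2, 4), (10, 5, 7), (1, 10, 2), matching the new CHECK lines; the two tied (3, 1, 9) tuples are ordered arbitrarily by quick/heap sort and kept in input order by insertion_sort_stable, which is visible only in the y buffers.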