[Mlir-commits] [mlir] [mlir][sparse] unifies sparse_tensor.sort_coo/sort into one operation. (PR #66722)
llvmlistbot at llvm.org
Mon Sep 18 17:09:33 PDT 2023
llvmbot wrote:
@llvm/pr-subscribers-mlir
<details>
<summary>Changes</summary>
The two operations largely overlap in functionality, so let's simplify and keep only one of them (`sparse_tensor.sort_coo`).
---
Patch is 91.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66722.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+13-62)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+10-35)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp (+142-172)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+7-5)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp (-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+16-22)
- (modified) mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir (+23-82)
- (modified) mlir/test/Dialect/SparseTensor/codegen.mlir (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+19-28)
- (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+7-57)
- (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir (+1-1)
- (removed) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir (-187)
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir (+27-24)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 94301dbcd9f7b42..59815fc755ee5f3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -762,81 +762,32 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
// Sparse Tensor Sorting Operations.
//===----------------------------------------------------------------------===//
-def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
- Arguments<(ins Index:$n,
- Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
- Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
- SparseTensorSortKindAttr:$algorithm)> {
- string summary = "Sorts the arrays in xs and ys lexicographically on the "
- "integral values found in the xs list";
- string description = [{
- Lexicographically sort the first `n` values in `xs` along with the values in
- `ys`. Conceptually, the values being sorted are tuples produced by
- `zip(zip(xs), zip(ys))`. In particular, values in `ys` needed to be sorted
- along with values in `xs`, but values in `ys` don't affect the
- lexicographical order. The order in which arrays appear in `xs` affects the
- sorting result. The operator updates `xs` and `ys` in place with the result
- of the sorting.
-
- For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
- "sort 2, x1, x2 jointly y1" are x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
- output of "sort 2, x2, x1, jointly y1" are x2=[1, 2], x1=[4, 3], y1=[10, 5].
-
- Buffers in `xs` needs to have the same integral element type while buffers
- in `ys` can have different numeric element types. All buffers in `xs` and
- `ys` should have a dimension not less than `n`. The behavior of the operator
- is undefined if this condition is not met. The operator requires at least
- one buffer in `xs` while `ys` can be empty.
-
- The enum attribute `algorithm` indicates the sorting algorithm used to
- implement the operator: hybrid_quick_sort, insertion_sort_stable,
- quick_sort, or heap_sort.
-
- Note that this operation is "impure" in the sense that its behavior is
- solely defined by side-effects and not SSA values.
-
- Example:
-
- ```mlir
- sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly y1, %y2
- : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
- ```
-
- ```mlir
- sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly y1, %y2
- { alg=1 : index}
- : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
- ```
- }];
- let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
- "`:` type($xs) (`jointly` type($ys)^)?";
- let hasVerifier = 1;
-}
-
def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
- OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+ AffineMapAttr:$perm_map, OptionalAttr<IndexAttr>:$ny,
SparseTensorSortKindAttr:$algorithm)> {
let summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
let description = [{
- Sparse_tensor.sort_coo is similar to sparse_tensor.sort, except that all the
- `xs` values and some `ys` values are put in the linear buffer `xy`. The
- optional index attribute `nx` provides the number of `xs` values in `xy`.
- When `nx` is not explicitly specified, its value is 1. The optional index
- attribute `ny` provides the number of `ys` values in `xy`. When `ny` is not
- explicitly specified, its value is 0. This instruction supports a more
- efficient way to store the COO definition in sparse tensor type.
-
- The buffer xy should have a dimension not less than n * (nx + ny) while the
+ Sparse_tensor.sort_coo sort the `xs` values along with some `ys` values
+ that are put in a single linear buffer `xy`.
+ The affine map attribute `perm_map` specifies the permutation to be applied on
+ the `xs` before comparison, the rank of the permutation map
+ also specifies the number of `xs` values in `xy`.
+ The optional index attribute `ny` provides the number of `ys` values in `xy`.
+ When `ny` is not explicitly specified, its value is 0.
+ This instruction supports a more efficient way to store the COO definition
+ in sparse tensor type.
+
+ The buffer xy should have a dimension not less than n * (rank(perm_map) + ny) while the
buffers in `ys` should have a dimension not less than `n`. The behavior of
the operator is undefined if this condition is not met.
Example:
```mlir
- sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
+ sparse_tensor.sort_coo insertion_sort_stable %n, %x { perm_map = affine_map<(i,j) -> (j,i)> }
: memref<?xindex>
```
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e71d2a8dd623a6a..9675a61109477b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1353,35 +1353,15 @@ LogicalResult SelectOp::verify() {
return success();
}
-LogicalResult SortOp::verify() {
- if (getXs().empty())
- return emitError("need at least one xs buffer.");
-
- std::optional<int64_t> n = getConstantIntValue(getN());
-
- Type xtp = getMemRefType(getXs().front()).getElementType();
- auto checkTypes = [&](ValueRange operands,
- bool checkEleType = true) -> LogicalResult {
- for (Value opnd : operands) {
- auto mtp = getMemRefType(opnd);
- const DynSize sh = mtp.getShape()[0];
- // We can't check the size of dynamic dimension at compile-time, but all
- // xs and ys should have a dimension not less than n at runtime.
- if (n && !ShapedType::isDynamic(sh) && sh < n.value())
- return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
- ": {0} < {1}",
- sh, n.value()));
-
- if (checkEleType && xtp != mtp.getElementType())
- return emitError("mismatch xs element types");
- }
- return success();
- };
- RETURN_FAILURE_IF_FAILED(checkTypes(getXs()))
- return n ? checkTypes(getYs(), false) : success();
-}
-
LogicalResult SortCooOp::verify() {
+ AffineMap xPerm = getPermMap();
+ uint64_t nx = xPerm.getNumDims();
+ if (nx < 1)
+ emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+
+ if (!xPerm.isPermutation())
+ emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+
std::optional<int64_t> cn = getConstantIntValue(getN());
// We can't check the size of the buffers when n or buffer dimensions aren't
// compile-time constants.
@@ -1389,12 +1369,6 @@ LogicalResult SortCooOp::verify() {
return success();
uint64_t n = cn.value();
- uint64_t nx = 1;
- if (auto nxAttr = getNxAttr()) {
- nx = nxAttr.getInt();
- if (nx < 1)
- emitError(llvm::formatv("Expected nx > 1, got {0}", nx));
- }
uint64_t ny = 0;
if (auto nyAttr = getNyAttr()) {
ny = nyAttr.getInt();
@@ -1409,7 +1383,8 @@ LogicalResult SortCooOp::verify() {
emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
};
- checkDim(getXy(), n * (nx + ny), "Expected dimension(xy) >= n * (nx + ny)");
+ checkDim(getXy(), n * (nx + ny),
+ "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
for (Value opnd : getYs()) {
checkDim(opnd, n, "Expected dimension(y) >= n");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 029ecb0708941fe..3181395a474cfec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -45,46 +45,43 @@ static constexpr const char kShiftDownFuncNamePrefix[] = "_sparse_shift_down_";
static constexpr const char kHeapSortFuncNamePrefix[] = "_sparse_heap_sort_";
static constexpr const char kQuickSortFuncNamePrefix[] = "_sparse_qsort_";
-using FuncGeneratorType = function_ref<void(
- OpBuilder &, ModuleOp, func::FuncOp, uint64_t, uint64_t, bool, uint32_t)>;
+using FuncGeneratorType = function_ref<void(OpBuilder &, ModuleOp, func::FuncOp,
+ AffineMap, uint64_t, uint32_t)>;
/// Constructs a function name with this format to facilitate quick sort:
-/// <namePrefix><nx>_<x type>_<y0 type>..._<yn type> for sort
-/// <namePrefix><nx>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
+/// <namePrefix><xPerm>_<x type>_<y0 type>..._<yn type> for sort
+/// <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type> for sort_coo
static void getMangledSortHelperFuncName(llvm::raw_svector_ostream &nameOstream,
- StringRef namePrefix, uint64_t nx,
- uint64_t ny, bool isCoo,
- ValueRange operands) {
- nameOstream << namePrefix << nx << "_"
- << getMemRefType(operands[xStartIdx]).getElementType();
+ StringRef namePrefix, AffineMap xPerm,
+ uint64_t ny, ValueRange operands) {
+ nameOstream << namePrefix;
+ for (auto res : xPerm.getResults())
+ nameOstream << res.cast<AffineDimExpr>().getPosition() << "_";
- if (isCoo)
- nameOstream << "_coo_" << ny;
+ nameOstream << getMemRefType(operands[xStartIdx]).getElementType();
+ nameOstream << "_coo_" << ny;
- uint64_t yBufferOffset = isCoo ? 1 : nx;
+ constexpr uint64_t yBufferOffset = 1;
for (Value v : operands.drop_front(xStartIdx + yBufferOffset))
nameOstream << "_" << getMemRefType(v).getElementType();
}
/// Looks up a function that is appropriate for the given operands being
/// sorted, and creates such a function if it doesn't exist yet. The
-/// parameters `nx` and `ny` tell the number of x and y values provided
-/// by the buffer in xStartIdx, and `isCoo` indicates whether the instruction
-/// being processed is a sparse_tensor.sort or sparse_tensor.sort_coo.
+/// parameters `xPerm` and `ny` tell the number of x and y values provided
+/// by the buffer in xStartIdx.
//
// All sorting function generators take (lo, hi, xs, ys) in `operands` as
// parameters for the sorting functions. Other parameters, such as the recursive
// call depth, are appended to the end of the parameter list as
// "trailing parameters".
-static FlatSymbolRefAttr
-getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
- TypeRange resultTypes, StringRef namePrefix,
- uint64_t nx, uint64_t ny, bool isCoo,
- ValueRange operands, FuncGeneratorType createFunc,
- uint32_t nTrailingP = 0) {
+static FlatSymbolRefAttr getMangledSortHelperFunc(
+ OpBuilder &builder, func::FuncOp insertPoint, TypeRange resultTypes,
+ StringRef namePrefix, AffineMap xPerm, uint64_t ny, ValueRange operands,
+ FuncGeneratorType createFunc, uint32_t nTrailingP = 0) {
SmallString<32> nameBuffer;
llvm::raw_svector_ostream nameOstream(nameBuffer);
- getMangledSortHelperFuncName(nameOstream, namePrefix, nx, ny, isCoo,
+ getMangledSortHelperFuncName(nameOstream, namePrefix, xPerm, ny,
operands.drop_back(nTrailingP));
ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
@@ -101,7 +98,7 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
loc, nameOstream.str(),
FunctionType::get(context, operands.getTypes(), resultTypes));
func.setPrivate();
- createFunc(builder, module, func, nx, ny, isCoo, nTrailingP);
+ createFunc(builder, module, func, xPerm, ny, nTrailingP);
}
return result;
@@ -110,27 +107,19 @@ getMangledSortHelperFunc(OpBuilder &builder, func::FuncOp insertPoint,
/// Creates a code block to process each pair of (xs[i], xs[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInXs(
- OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
- Value iOffset, jOffset;
- if (isCoo) {
- Value cstep = constantIndex(builder, loc, nx + ny);
- iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
- jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
- }
- for (uint64_t k = 0; k < nx; k++) {
- scf::IfOp ifOp;
- Value i, j, buffer;
- if (isCoo) {
- Value ck = constantIndex(builder, loc, k);
- i = builder.create<arith::AddIOp>(loc, ck, iOffset);
- j = builder.create<arith::AddIOp>(loc, ck, jOffset);
- buffer = args[xStartIdx];
- } else {
- i = args[0];
- j = args[1];
- buffer = args[xStartIdx + k];
- }
+ OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+ uint64_t ny,
+ function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+ Value cstep = constantIndex(builder, loc, xPerm.getNumResults() + ny);
+ Value iOffset = builder.create<arith::MulIOp>(loc, args[0], cstep);
+ Value jOffset = builder.create<arith::MulIOp>(loc, args[1], cstep);
+ for (unsigned k = 0, e = xPerm.getNumResults(); k < e; k++) {
+ unsigned actualK = xPerm.getResult(k).cast<AffineDimExpr>().getPosition();
+ Value ak = constantIndex(builder, loc, actualK);
+ Value i = builder.create<arith::AddIOp>(loc, ak, iOffset);
+ Value j = builder.create<arith::AddIOp>(loc, ak, jOffset);
+ Value buffer = args[xStartIdx];
+
bodyBuilder(k, i, j, buffer);
}
}
@@ -138,21 +127,28 @@ static void forEachIJPairInXs(
/// Creates a code block to process each pair of (xys[i], xys[j]) for sorting.
/// The code to process the value pairs is generated by `bodyBuilder`.
static void forEachIJPairInAllBuffers(
- OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
-
- // Create code for the first (nx + ny) buffers. When isCoo==true, these
- // logical buffers are all from the xy buffer of the sort_coo operator.
- forEachIJPairInXs(builder, loc, args, nx + ny, 0, isCoo, bodyBuilder);
+ OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+ uint64_t ny,
+ function_ref<void(uint64_t, Value, Value, Value)> bodyBuilder) {
+
+ // Create code for the first (xPerm + ny) buffers.
+ SmallVector<AffineExpr> exps(xPerm.getResults().begin(),
+ xPerm.getResults().end());
+ for (unsigned y = 0; y < ny; y++) {
+ exps.push_back(builder.getAffineDimExpr(y + xPerm.getNumResults()));
+ }
+ AffineMap xyPerm = AffineMap::get(exps.size(), 0, exps, builder.getContext());
+ assert(xyPerm.isPermutation());
- uint64_t numHandledBuffers = isCoo ? 1 : nx + ny;
+ forEachIJPairInXs(builder, loc, args, xyPerm, 0, bodyBuilder);
+ constexpr uint64_t numHandledBuffers = 1;
// Create code for the remaining buffers.
Value i = args[0];
Value j = args[1];
for (const auto &arg :
llvm::enumerate(args.drop_front(xStartIdx + numHandledBuffers))) {
- bodyBuilder(arg.index() + nx + ny, i, j, arg.value());
+ bodyBuilder(arg.index() + xPerm.getNumResults() + ny, i, j, arg.value());
}
}
@@ -168,7 +164,7 @@ static void forEachIJPairInAllBuffers(
// ...
// swap(yn[i], yn[j]);
static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
- uint64_t nx, uint64_t ny, bool isCoo) {
+ AffineMap xPerm, uint64_t ny) {
auto swapOnePair = [&](uint64_t unused, Value i, Value j, Value buffer) {
Value vi = builder.create<memref::LoadOp>(loc, buffer, i);
Value vj = builder.create<memref::LoadOp>(loc, buffer, j);
@@ -176,20 +172,20 @@ static void createSwap(OpBuilder &builder, Location loc, ValueRange args,
builder.create<memref::StoreOp>(loc, vi, buffer, j);
};
- forEachIJPairInAllBuffers(builder, loc, args, nx, ny, isCoo, swapOnePair);
+ forEachIJPairInAllBuffers(builder, loc, args, xPerm, ny, swapOnePair);
}
/// Creates code to compare all the (xs[i], xs[j]) pairs. The method to compare
/// each pair is create via `compareBuilder`.
static Value createInlinedCompareImplementation(
- OpBuilder &builder, Location loc, ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo,
+ OpBuilder &builder, Location loc, ValueRange args, AffineMap xPerm,
+ uint64_t ny,
function_ref<Value(OpBuilder &, Location, Value, Value, Value, bool, bool)>
compareBuilder) {
Value result;
auto bodyBuilder = [&](uint64_t k, Value i, Value j, Value buffer) {
bool isFirstDim = (k == 0);
- bool isLastDim = (k == nx - 1);
+ bool isLastDim = (k == xPerm.getNumResults() - 1);
Value val =
compareBuilder(builder, loc, i, j, buffer, isFirstDim, isLastDim);
if (isFirstDim) {
@@ -202,7 +198,7 @@ static Value createInlinedCompareImplementation(
}
};
- forEachIJPairInXs(builder, loc, args, nx, ny, isCoo, bodyBuilder);
+ forEachIJPairInXs(builder, loc, args, xPerm, ny, bodyBuilder);
builder.setInsertionPointAfterValue(result);
return result;
@@ -252,12 +248,12 @@ static Value createEqCompare(OpBuilder &builder, Location loc, Value i, Value j,
// else if (x2[2] != x2[j]))
// and so on ...
static Value createInlinedEqCompare(OpBuilder &builder, Location loc,
- ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP = 0) {
+ ValueRange args, AffineMap xPerm,
+ uint64_t ny, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
- return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+ return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
createEqCompare);
}
@@ -306,12 +302,12 @@ static Value createLessThanCompare(OpBuilder &builder, Location loc, Value i,
// else
// and so on ...
static Value createInlinedLessThan(OpBuilder &builder, Location loc,
- ValueRange args, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP = 0) {
+ ValueRange args, AffineMap xPerm,
+ uint64_t ny, uint32_t nTrailingP = 0) {
// Compare functions don't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
- return createInlinedCompareImplementation(builder, loc, args, nx, ny, isCoo,
+ return createInlinedCompareImplementation(builder, loc, args, xPerm, ny,
createLessThanCompare);
}
@@ -329,8 +325,8 @@ static Value createInlinedLessThan(OpBuilder &builder, Location loc,
// return lo;
//
static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
- func::FuncOp func, uint64_t nx, uint64_t ny,
- bool isCoo, uint32_t nTrailingP = 0) {
+ func::FuncOp func, AffineMap xPerm,
+ uint64_t ny, uint32_t nTrailingP = 0) {
// Binary search doesn't use trailing parameters.
(void)nTrailingP;
assert(nTrailingP == 0);
@@ -368,11 +364,10 @@ static void createBinarySearchFunc(OpBuilder &builder, ModuleOp module,
// Compare xs[p] < xs[mid].
SmallVector<Value> compareOperands{p, mid};
- uint64_t numXBuffers = isCoo ? 1 : nx;
+ constexpr uint64_t numXBuffers = 1;
compareOp...
[truncated]
``````````
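As a rough reference for the new `sort_coo` semantics described in the diff above, here is a minimal C++ model. It is not the MLIR lowering. `sortCooReference` and its signature are hypothetical, and the sketch assumes the following: `xy` holds `n` tuples of stride `rank(perm_map) + ny`, `perm[k]` is the dimension position of the k-th `perm_map` result, and the separate `ys` buffers hold `double` values purely for illustration. It uses `std::stable_sort`, which matches the `insertion_sort_stable` choice. The quick-sort and heap-sort variants may not preserve the order of equal keys.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical reference model of sparse_tensor.sort_coo with perm_map.
// xy: n tuples laid out contiguously, each with nx key values followed by
//     ny carried values (stride = nx + ny).
// perm: perm[k] is the position inside a tuple of the k-th comparison key.
// ys: separate buffers permuted alongside xy (double is an assumption).
void sortCooReference(std::size_t n, std::vector<std::int64_t> &xy,
                      const std::vector<std::size_t> &perm, std::size_t ny,
                      std::vector<std::vector<double>> &ys) {
  const std::size_t nx = perm.size();
  const std::size_t stride = nx + ny;
  // Mirrors the verifier: dimension(xy) >= n * (rank(perm_map) + ny).
  assert(nx >= 1 && xy.size() >= n * stride);
  for (const auto &y : ys)
    assert(y.size() >= n);

  // Stable order of tuple indices, compared lexicographically on the
  // permuted keys; carried values never affect the order.
  std::vector<std::size_t> order(n);
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::stable_sort(order.begin(), order.end(),
                   [&](std::size_t a, std::size_t b) {
                     for (std::size_t k = 0; k < nx; ++k) {
                       const std::int64_t va = xy[a * stride + perm[k]];
                       const std::int64_t vb = xy[b * stride + perm[k]];
                       if (va != vb)
                         return va < vb;
                     }
                     return false;
                   });

  // Move whole tuples (keys and carried values) in xy.
  std::vector<std::int64_t> sortedXy(n * stride);
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t c = 0; c < stride; ++c)
      sortedXy[i * stride + c] = xy[order[i] * stride + c];
  std::copy(sortedXy.begin(), sortedXy.end(), xy.begin());

  // Apply the same order to the first n entries of each separate y buffer.
  for (auto &y : ys) {
    std::vector<double> sortedY(n);
    for (std::size_t i = 0; i < n; ++i)
      sortedY[i] = y[order[i]];
    std::copy(sortedY.begin(), sortedY.end(), y.begin());
  }
}
```

Interleaving the removed `sort` example (x1=[4, 3], x2=[1, 2], y1=[10, 5]) as `xy = {4, 1, 3, 2}` with `ys = {{10, 5}}` shows the overlap. `perm = {0, 1}` gives the old "sort x1, x2" result (x1=[3, 4], x2=[2, 1], y1=[5, 10]). `perm = {1, 0}` gives the old "sort x2, x1" result, which leaves everything unchanged.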
</details>
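The helper-name scheme in `getMangledSortHelperFuncName` now encodes each `perm_map` result position. As a result, sorts that differ only in their permutation get distinct generated helper functions. Below is a string-only sketch of that scheme. `mangleSortHelperName` is hypothetical and takes element types as already-printed strings (for example `"index"`, `"f32"`) rather than MLIR `Type` values. In the code shown in the diff, `_coo_<ny>` is always appended.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical model of the name format
//   <namePrefix><xPerm>_<x type>_coo_<ny>_<y0 type>..._<yn type>
// permPositions: dimension positions of the perm_map results, in order.
// yTypes: element types of the separate y buffers, as printed strings.
std::string mangleSortHelperName(const std::string &prefix,
                                 const std::vector<unsigned> &permPositions,
                                 const std::string &xType, std::uint64_t ny,
                                 const std::vector<std::string> &yTypes) {
  std::string name = prefix;
  // One "<pos>_" per permutation result.
  for (unsigned pos : permPositions)
    name += std::to_string(pos) + "_";
  // Element type of the xy buffer, then the number of ys stored in xy.
  name += xType;
  name += "_coo_" + std::to_string(ny);
  // One "_<type>" per separate y buffer.
  for (const auto &t : yTypes)
    name += "_" + t;
  return name;
}
```

For example, with the `_sparse_qsort_` prefix, `perm_map = affine_map<(i,j) -> (j,i)>`, an `index` xy buffer, `ny = 0` and one `f32` y buffer, this yields `_sparse_qsort_1_0_index_coo_0_f32`.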
https://github.com/llvm/llvm-project/pull/66722