[Mlir-commits] [mlir] cf24d49 - [mlir][sparse] Add sparse_tensor.sort_coo operator.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 7 08:23:56 PST 2022
Author: bixia1
Date: 2022-11-07T08:23:51-08:00
New Revision: cf24d49dc81b06e8efff15bd77f332840180867c
URL: https://github.com/llvm/llvm-project/commit/cf24d49dc81b06e8efff15bd77f332840180867c
DIFF: https://github.com/llvm/llvm-project/commit/cf24d49dc81b06e8efff15bd77f332840180867c.diff
LOG: [mlir][sparse] Add sparse_tensor.sort_coo operator.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137442
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index a22dcce4298ef..52a6aff752792 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -518,6 +518,45 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
let hasVerifier = 1;
}
+def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
+    Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
+               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
+               OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
+               UnitAttr:$stable)> {
+  let summary = "Sorts the arrays in xy and ys lexicographically on the "
+                "integral xs values stored in xy";
+  let description = [{
+    `sparse_tensor.sort_coo` is similar to `sparse_tensor.sort`, except that
+    all the `xs` values and some `ys` values are put in the linear buffer
+    `xy`. The optional index attribute `nx` provides the number of `xs`
+    values in `xy`. When `nx` is not explicitly specified, its value is 1; an
+    explicitly specified `nx` must be at least 1. The optional index
+    attribute `ny` provides the number of `ys` values in `xy`. When `ny` is
+    not explicitly specified, its value is 0. This instruction supports the
+    TACO COO style storage format for better sorting performance.
+
+    The buffer `xy` should have a dimension not less than `n * (nx + ny)`,
+    while the buffers in `ys` should have a dimension not less than `n`. The
+    behavior of the operator is undefined if this condition is not met.
+    Violations that are visible statically (a constant `n` and static buffer
+    dimensions) are rejected by the verifier.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.sort_coo %n, %xy {nx = 2 : index}
+      : memref<?xindex>
+    ```
+
+    ```mlir
+    sparse_tensor.sort_coo %n, %xy jointly %y1 {nx = 2 : index, ny = 2 : index}
+      : memref<?xi64> jointly memref<?xf32>
+    ```
+  }];
+
+  let assemblyFormat = "(`stable` $stable^)? $n"
+                       "`,`$xy (`jointly` $ys^)? attr-dict"
+                       "`:` type($xy) (`jointly` type($ys)^)?";
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Syntax Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 9d8cf37befd49..693af03a94cb5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -719,6 +719,42 @@ LogicalResult SortOp::verify() {
return success();
}
+/// Verifies `sparse_tensor.sort_coo`.
+///
+/// The `nx` attribute must be at least 1 and the `ny` attribute must be
+/// non-negative. When `n` is defined by an `arith.constant`, every buffer with
+/// a static dimension must also be large enough: `xy` needs at least
+/// `n * (nx + ny)` elements and each buffer in `ys` at least `n`. The first
+/// violation found is reported as an error and makes verification fail.
+LogicalResult SortCooOp::verify() {
+  // Validate the attributes first: they do not depend on `n`, so they must be
+  // checked even when `n` is not a compile-time constant. Read them as signed
+  // values so that negative inputs are rejected instead of wrapping around.
+  uint64_t nx = 1;
+  if (auto nxAttr = getNxAttr()) {
+    int64_t nxVal = nxAttr.getInt();
+    if (nxVal < 1)
+      return emitError(llvm::formatv("Expected nx >= 1, got {0}", nxVal));
+    nx = nxVal;
+  }
+  uint64_t ny = 0;
+  if (auto nyAttr = getNyAttr()) {
+    int64_t nyVal = nyAttr.getInt();
+    if (nyVal < 0)
+      return emitError(llvm::formatv("Expected ny >= 0, got {0}", nyVal));
+    ny = nyVal;
+  }
+
+  auto cn = getN().getDefiningOp<arith::ConstantIndexOp>();
+  // We can't check the size of the buffers when n or buffer dimensions aren't
+  // compile-time constants.
+  if (!cn)
+    return success();
+  uint64_t n = cn.value();
+
+  // Reports an error (and fails) when `v` has a static dimension below `min`.
+  auto checkDim = [&](Value v, uint64_t min,
+                      const char *message) -> LogicalResult {
+    MemRefType tp = v.getType().cast<MemRefType>();
+    int64_t dim = tp.getShape()[0];
+    // Static dimensions are non-negative, so the unsigned comparison is exact.
+    if (dim != ShapedType::kDynamicSize && static_cast<uint64_t>(dim) < min)
+      return emitError(llvm::formatv("{0} got {1} < {2}", message, dim, min));
+    return success();
+  };
+
+  if (failed(checkDim(getXy(), n * (nx + ny),
+                      "Expected dimension(xy) >= n * (nx + ny)")))
+    return failure();
+  for (Value opnd : getYs())
+    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
+      return failure();
+
+  return success();
+}
+
LogicalResult YieldOp::verify() {
// Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 407f19401b86b..02fb97bc866c6 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -622,6 +622,32 @@ func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %a
// -----
+// A floating-point `xy` buffer violates the integer/index operand constraint.
+func.func @sparse_sort_coo_x_type(%n: index, %xy: memref<?xf32>) {
+  // expected-error at +1 {{operand #1 must be 1D memref of integer or index values}}
+  sparse_tensor.sort_coo %n, %xy : memref<?xf32>
+  return
+}
+
+// -----
+
+// With n = 20, nx = 2 and ny = 1, `xy` needs 60 elements but only has 50.
+func.func @sparse_sort_coo_x_too_small(%xy: memref<50xindex>) {
+  %n = arith.constant 20 : index
+  // expected-error at +1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
+  sparse_tensor.sort_coo %n, %xy {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+  return
+}
+
+// -----
+
+// `xy` is exactly large enough (60 = 20 * 3), but the joint `y` buffer has
+// only 10 of the required 20 elements.
+func.func @sparse_sort_coo_y_too_small(%xy: memref<60xindex>, %y: memref<10xf32>) {
+  %n = arith.constant 20 : index
+  // expected-error at +1 {{Expected dimension(y) >= n got 10 < 20}}
+  sparse_tensor.sort_coo %n, %xy jointly %y {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+  return
+}
+
+// -----
+
#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 7f850ccbbc4e2..bc664ae3d2d00 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -484,3 +484,18 @@ func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<
sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_coo(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+//       CHECK: sparse_tensor.sort_coo %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+//       CHECK: return %[[B]]
+func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  sparse_tensor.sort_coo %arg0, %arg1 {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_coo_stable(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<?xi64>,
+//  CHECK-SAME: %[[C:.*]]: memref<?xf32>)
+//       CHECK: sparse_tensor.sort_coo stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index} : memref<?xi64> jointly memref<?xf32>
+//       CHECK: return %[[B]], %[[C]]
+func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
+  sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index} : memref<?xi64> jointly memref<?xf32>
+  return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
+}
+
More information about the Mlir-commits
mailing list