[Mlir-commits] [mlir] 8c02ca1 - [mlir][sparse] Add an attribute to the sort operator for stable sorting.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 4 15:14:13 PDT 2022
Author: bixia1
Date: 2022-10-04T15:14:03-07:00
New Revision: 8c02ca1da5bdcb7f7e850afb24d95bb6d82d8971
URL: https://github.com/llvm/llvm-project/commit/8c02ca1da5bdcb7f7e850afb24d95bb6d82d8971
DIFF: https://github.com/llvm/llvm-project/commit/8c02ca1da5bdcb7f7e850afb24d95bb6d82d8971.diff
LOG: [mlir][sparse] Add an attribute to the sort operator for stable sorting.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D135181
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 8cd1a01a2e330..4d1f23719ee08 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -415,7 +415,8 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
// and then use NonemptyVariadic<...>:$xs here.
Arguments<(ins Index:$n,
Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
- Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys)> {
+ Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
+ UnitAttr:$stable)> {
string summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
string description = [{
@@ -437,6 +438,9 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
is undefined if this condition is not met. The operator requires at least
one buffer in `xs` while `ys` can be empty.
+ When the `stable` attribute is present, a stable sorting algorithm should be
+ used, so that entries whose `xs` keys compare equal keep their relative order.
+
Note that this operation is "impure" in the sense that its behavior is
solely defined by side-effects and not SSA values. The semantics may be
refined over time as our sparse abstractions evolve.
@@ -447,10 +451,18 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2
: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
```
+
+ ```mlir
+ sparse_tensor.sort stable %n, %x1, %x2 jointly %y1, %y2
+ : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
+ ```
}];
- let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict"
+ let assemblyFormat = "(`stable` $stable^)? $n"
+ "`,`$xs (`jointly` $ys^)? attr-dict"
"`:` type($xs) (`jointly` type($ys)^)?";
-
+ let builders = [
+ OpBuilder<(ins "Value":$n, "ValueRange":$xs, "ValueRange":$ys)>
+ ];
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d12ecb9d023f6..9f12a5481e406 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -675,6 +675,11 @@ LogicalResult SelectOp::verify() {
return success();
}
+/// Convenience builder that creates a `sparse_tensor.sort` without the
+/// `stable` unit attribute, i.e. no stable sorting algorithm is requested.
+void SortOp::build(OpBuilder &builder, OperationState &result, Value n,
+                   ValueRange xs, ValueRange ys) {
+  const bool stable = false;
+  build(builder, result, n, xs, ys, stable);
+}
+
LogicalResult SortOp::verify() {
if (getXs().empty())
return emitError("need at least one xs buffer.");
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index ee69af3d256a2..fd850aacacae7 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -423,3 +423,17 @@ func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20
sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_stable(
+// CHECK-SAME: %[[N:.*]]: index,
+// CHECK-SAME: %[[X0:.*]]: memref<10xi8>,
+// CHECK-SAME: %[[X1:.*]]: memref<20xi8>,
+// CHECK-SAME: %[[Y0:.*]]: memref<10xf64>)
+// CHECK: sparse_tensor.sort stable %[[N]], %[[X0]], %[[X1]] jointly %[[Y0]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+// CHECK: return %[[X0]], %[[X1]], %[[Y0]]
+func.func @sparse_sort_stable(%n: index, %x0: memref<10xi8>, %x1: memref<20xi8>, %y0: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
+ sparse_tensor.sort stable %n, %x0, %x1 jointly %y0 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+ return %x0, %x1, %y0 : memref<10xi8>, memref<20xi8>, memref<10xf64>
+}
More information about the Mlir-commits mailing list