[Mlir-commits] [mlir] 0c7f1c1 - [mlir][sparse] Extend sparse_tensor.sort with an enum attribute to specify a sorting implementation.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 29 18:34:13 PST 2023
Author: bixia1
Date: 2023-01-29T18:34:08-08:00
New Revision: 0c7f1c152021249a97640d0ec9396e3885b9dbcc
URL: https://github.com/llvm/llvm-project/commit/0c7f1c152021249a97640d0ec9396e3885b9dbcc
DIFF: https://github.com/llvm/llvm-project/commit/0c7f1c152021249a97640d0ec9396e3885b9dbcc.diff
LOG: [mlir][sparse] Extend sparse_tensor.sort with an enum attribute to specify a sorting implementation.
Currently, all the non-stable sorting algorithms are implemented via the
straightforward quick sort. This will be fixed in the following PR.
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142678
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
mlir/test/Dialect/SparseTensor/codegen.mlir
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 43c493c1e0f56..f6fc8fd4f848c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -325,4 +325,39 @@ class RankedSparseTensorOf<list<Type> allowedTypes>
def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Sorting Algorithm Attribute.
+//===----------------------------------------------------------------------===//
+
+// TODO: Currently, we only provide four implementations, and expose the
+// implementations via the `algorithm` attribute. In the future, if we need
+// to support both stable and non-stable quick sort, we may add a
+// quick_sort_nonstable enum to the attribute. Alternatively, we may use two
+// attributes, (stable|nonstable, algorithm), to specify a sorting
+// implementation.
+//
+// --------------------------------------------------------------------------
+// | | hybrid_qsort| insertion_sort | qsort | heap_sort  |
+// |non-stable | Impl | X | Impl | Impl |
+// |stable | X | Impl | Not Impl | X |
+// --------------------------------------------------------------------------
+
+// The C++ enum for sparse tensor sort kind.
+def SparseTensorSortKindEnum
+ : I32EnumAttr<"SparseTensorSortKind", "sparse tensor sort algorithm", [
+ I32EnumAttrCase<"HybridQuickSort", 0, "hybrid_quick_sort">,
+ I32EnumAttrCase<"InsertionSortStable", 1, "insertion_sort_stable">,
+ I32EnumAttrCase<"QuickSort", 2, "quick_sort">,
+ I32EnumAttrCase<"HeapSort", 3, "heap_sort">,
+ ]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = SparseTensor_Dialect.cppNamespace;
+}
+
+// Define the enum sparse tensor sort kind attribute.
+def SparseTensorSortKindAttr
+ : EnumAttr<SparseTensor_Dialect, SparseTensorSortKindEnum,
+ "SparseTensorSortAlgorithm"> {
+}
+
#endif // SPARSETENSOR_ATTRDEFS
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 4690643709143..521df943d657c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -581,10 +581,15 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
// TODO: May want to extend tablegen with
// class NonemptyVariadic<Type type> : Variadic<type> { let minSize = 1; }
// and then use NonemptyVariadic<...>:$xs here.
+ //
+ // TODO: Currently tablegen doesn't support the assembly syntax when
+ // `algorithm` is an optional enum attribute. We may want to use an optional
+ // enum attribute when this is fixed in tablegen.
+ //
Arguments<(ins Index:$n,
Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
- UnitAttr:$stable)> {
+ SparseTensorSortKindAttr:$algorithm)> {
string summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
string description = [{
@@ -606,8 +611,9 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
is undefined if this condition is not met. The operator requires at least
one buffer in `xs` while `ys` can be empty.
- The `stable` attribute indicates whether a stable sorting algorithm should
- be used to implement the operator.
+ The enum attribute `algorithm` indicates the sorting algorithm used to
+ implement the operator: hybrid_quick_sort, insertion_sort_stable,
+ quick_sort, or heap_sort.
Note that this operation is "impure" in the sense that its behavior is
solely defined by side-effects and not SSA values.
@@ -615,17 +621,17 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
Example:
```mlir
- sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2
+ sparse_tensor.sort insertion_sort_stable %n, %x1, %x2 jointly %y1, %y2
: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
```
```mlir
- sparse_tensor.sort stable %n, %x1, %x2 jointly y1, %y2
+ sparse_tensor.sort hybrid_quick_sort %n, %x1, %x2 jointly %y1, %y2
+ { alg=1 : index}
: memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
```
}];
- let assemblyFormat = "(`stable` $stable^)? $n"
- "`,`$xs (`jointly` $ys^)? attr-dict"
+ let assemblyFormat = "$algorithm $n `,` $xs (`jointly` $ys^)? attr-dict"
"`:` type($xs) (`jointly` type($ys)^)?";
let hasVerifier = 1;
}
@@ -634,7 +640,7 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Arguments<(ins Index:$n, StridedMemRefRankOf<[AnyInteger, Index], [1]>:$xy,
Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
OptionalAttr<IndexAttr>:$nx, OptionalAttr<IndexAttr>:$ny,
- UnitAttr:$stable)> {
+ SparseTensorSortKindAttr:$algorithm)> {
let summary = "Sorts the arrays in xs and ys lexicographically on the "
"integral values found in the xs list";
let description = [{
@@ -653,17 +659,18 @@ def SparseTensor_SortCooOp : SparseTensor_Op<"sort_coo">,
Example:
```mlir
- sparse_tensor.sort_coo %n, %x { nx = 2 : index}
+ sparse_tensor.sort_coo insertion_sort_stable %n, %x { nx = 2 : index}
: memref<?xindex>
```
```mlir
- sparse_tensor.sort %n, %xy jointly %y1 { nx = 2 : index, ny = 2 : index}
+ sparse_tensor.sort_coo hybrid_quick_sort %n, %xy jointly %y1
+ { nx = 2 : index, ny = 2 : index}
: memref<?xi64> jointly memref<?xf32>
```
}];
- let assemblyFormat = "(`stable` $stable^)? $n"
+ let assemblyFormat = "$algorithm $n"
"`,`$xy (`jointly` $ys^)? attr-dict"
"`:` type($xy) (`jointly` type($ys)^)?";
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
index 3fc760ee756cb..90dadf71e61b6 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -814,11 +814,13 @@ LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, uint64_t nx,
}
operands.push_back(v);
}
+ bool isStable =
+ (op.getAlgorithm() == SparseTensorSortKind::InsertionSortStable);
auto insertPoint = op->template getParentOfType<func::FuncOp>();
- SmallString<32> funcName(op.getStable() ? kSortStableFuncNamePrefix
- : kSortNonstableFuncNamePrefix);
+ SmallString<32> funcName(isStable ? kSortStableFuncNamePrefix
+ : kSortNonstableFuncNamePrefix);
FuncGeneratorType funcGenerator =
- op.getStable() ? createSortStableFunc : createSortNonstableFunc;
+ isStable ? createSortStableFunc : createSortNonstableFunc;
FlatSymbolRefAttr func =
getMangledSortHelperFunc(rewriter, insertPoint, TypeRange(), funcName, nx,
ny, isCoo, operands, funcGenerator);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 2ce29e59029d9..f96aeeeab4027 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -819,7 +819,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
// in the "added" array prior to applying the compression.
unsigned rank = dstType.getShape().size();
if (isOrderedDim(dstType, rank - 1))
- rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{});
+ rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{},
+ SparseTensorSortKind::HybridQuickSort);
// While performing the insertions, we also need to reset the elements
// of the values/filled-switch by only iterating over the set elements,
// to ensure that the runtime complexity remains proportional to the
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 22ec4791066f3..1e2a7b017b7ff 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -866,9 +866,9 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
get1DMemRefType(getIndexOverheadType(rewriter, encSrc),
/*withLayout=*/false);
Value xs = rewriter.create<ToIndicesBufferOp>(loc, indTp, src);
- rewriter.create<SortCooOp>(loc, nnz, xs, ValueRange{y},
- rewriter.getIndexAttr(rank),
- rewriter.getIndexAttr(0));
+ rewriter.create<SortCooOp>(
+ loc, nnz, xs, ValueRange{y}, rewriter.getIndexAttr(rank),
+ rewriter.getIndexAttr(0), SparseTensorSortKind::HybridQuickSort);
} else {
// Gather the indices-arrays in the dst tensor storage order.
SmallVector<Value> xs(rank, Value());
@@ -877,7 +877,8 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
xs[toStoredDim(encDst, orgDim)] =
genToIndices(rewriter, loc, src, i, /*cooStart=*/0);
}
- rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y});
+ rewriter.create<SortOp>(loc, nnz, xs, ValueRange{y},
+ SparseTensorSortKind::HybridQuickSort);
}
}
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
index 7e10ae17a1c44..b9d56f470d934 100644
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -55,7 +55,7 @@ func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f64) -> (
// CHECK: return %[[M]], %[[S2]] : memref<?xf64>, index
func.func @sparse_push_back_n(%arg0: index, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> (memref<?xf64>, index) {
%0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 : index, memref<?xf64>, f64, index
- return %0#0, %0#1 : memref<?xf64>, index
+ return %0#0, %0#1 : memref<?xf64>, index
}
// -----
@@ -155,7 +155,7 @@ func.func @sparse_push_back_inbound(%arg0: index, %arg1: memref<?xf64>, %arg2: f
// CHECK: }
func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?xf32>, %arg3: memref<10xindex>)
-> (memref<10xi8>, memref<?xf32>, memref<10xindex>) {
- sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
+ sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<10xi8> jointly memref<?xf32>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<?xf32>, memref<10xindex>
}
@@ -170,7 +170,7 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<?x
// CHECK-DAG: func.func private @_sparse_sort_nonstable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d
func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
- sparse_tensor.sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+ sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
@@ -184,7 +184,7 @@ func.func @sparse_sort_3d(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?
// CHECK-DAG: func.func private @_sparse_sort_stable_3_index(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xindex>, %arg4: memref<?xindex>) {
// CHECK-LABEL: func.func @sparse_sort_3d_stable
func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<?xindex>, %arg3: memref<10xindex>) -> (memref<10xindex>, memref<?xindex>, memref<10xindex>) {
- sparse_tensor.sort stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
+ sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
return %arg1, %arg2, %arg3 : memref<10xindex>, memref<?xindex>, memref<10xindex>
}
@@ -199,7 +199,7 @@ func.func @sparse_sort_3d_stable(%arg0: index, %arg1: memref<10xindex>, %arg2: m
// CHECK-DAG: func.func private @_sparse_sort_nonstable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo
func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
- sparse_tensor.sort_coo %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
@@ -213,7 +213,7 @@ func.func @sparse_sort_coo(%arg0: index, %arg1: memref<100xindex>, %arg2: memref
// CHECK-DAG: func.func private @_sparse_sort_stable_2_index_coo_1_f32_i32(%arg0: index, %arg1: index, %arg2: memref<?xindex>, %arg3: memref<?xf32>, %arg4: memref<?xi32>) {
// CHECK-LABEL: func.func @sparse_sort_coo_stable
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
- sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
+ sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index 61c4324cf1a41..1e06e65212468 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -430,7 +430,7 @@ func.func @sparse_expansion3(%arg0: index, %arg1: index) -> memref<?xindex> {
// CHECK-DAG: %[[A9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[A10:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[A11:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
// CHECK: %[[A12:.*]]:4 = scf.for %[[A13:.*]] = %[[A11]] to %[[A7]] step %[[A10]] iter_args(%[[A14:.*]] = %[[A0]], %[[A15:.*]] = %[[A1]], %[[A16:.*]] = %[[A2]], %[[A17:.*]] = %[[A3]])
// CHECK: %[[A18:.*]] = memref.load %[[A6]]{{\[}}%[[A13]]] : memref<?xindex>
// CHECK: %[[A19:.*]] = memref.load %[[A4]]{{\[}}%[[A18]]] : memref<?xf64>
@@ -478,7 +478,7 @@ func.func @sparse_compression_1d(%tensor: tensor<100xf64, #SV>,
// CHECK: %[[A11:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[A12:.*]] = arith.constant 1 : index
// CHECK: %[[A13:.*]] = arith.constant 0 : index
-// CHECK: sparse_tensor.sort %[[A7]], %[[A6]] : memref<?xindex>
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A7]], %[[A6]] : memref<?xindex>
// CHECK: %[[A14:.*]]:4 = scf.for %[[A15:.*]] = %[[A13]] to %[[A7]] step %[[A12]] iter_args(%[[A16:.*]] = %[[A0]], %[[A17:.*]] = %[[A1]], %[[A18:.*]] = %[[A2]], %[[A19:.*]] = %[[A3]]) -> (memref<?xi32>, memref<?xi64>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[A20:.*]] = memref.load %[[A6]]{{\[}}%[[A15]]] : memref<?xindex>
// CHECK: %[[A21:.*]] = memref.load %[[A4]]{{\[}}%[[A20]]] : memref<?xf64>
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 6bdd5cd346945..2646b2db71148 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -195,7 +195,7 @@ func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{
// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
-// CHECK-RWT: sparse_tensor.sort_coo %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
+// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 531a987a6315e..feb45e184d67a 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -696,7 +696,7 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
// expected-error at +1 {{operand #1 must be 1D memref of integer or index values}}
- sparse_tensor.sort %arg0, %arg1: memref<?xf32>
+ sparse_tensor.sort hybrid_quick_sort %arg0, %arg1: memref<?xf32>
}
// -----
@@ -704,7 +704,7 @@ func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
%i20 = arith.constant 20 : index
// expected-error at +1 {{xs and ys need to have a dimension >= n: 10 < 20}}
- sparse_tensor.sort %i20, %arg0 : memref<10xindex>
+ sparse_tensor.sort insertion_sort_stable %i20, %arg0 : memref<10xindex>
return
}
@@ -712,7 +712,7 @@ func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
// expected-error at +1 {{mismatch xs element types}}
- sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+ sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
return
}
@@ -720,7 +720,7 @@ func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %a
func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
// expected-error at +1 {{operand #1 must be 1D memref of integer or index values}}
- sparse_tensor.sort_coo %arg0, %arg1: memref<?xf32>
+ sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1: memref<?xf32>
return
}
@@ -729,7 +729,7 @@ func.func @sparse_sort_coo_x_type( %arg0: index, %arg1: memref<?xf32>) {
func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
%i20 = arith.constant 20 : index
// expected-error at +1 {{Expected dimension(xy) >= n * (nx + ny) got 50 < 60}}
- sparse_tensor.sort_coo %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
+ sparse_tensor.sort_coo hybrid_quick_sort %i20, %arg0 {nx = 2 : index, ny = 1 : index} : memref<50xindex>
return
}
@@ -738,7 +738,7 @@ func.func @sparse_sort_coo_x_too_small(%arg0: memref<50xindex>) {
func.func @sparse_sort_coo_y_too_small(%arg0: memref<60xindex>, %arg1: memref<10xf32>) {
%i20 = arith.constant 20 : index
// expected-error at +1 {{Expected dimension(y) >= n got 10 < 20}}
- sparse_tensor.sort_coo %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
+ sparse_tensor.sort_coo insertion_sort_stable %i20, %arg0 jointly %arg1 {nx = 2 : index, ny = 1 : index} : memref<60xindex> jointly memref<10xf32>
return
}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 58375d6e3f3c5..1f48953f95fce 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -504,10 +504,10 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> (
// CHECK-LABEL: func @sparse_sort_1d0v(
// CHECK-SAME: %[[A:.*]]: index,
// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
-// CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex>
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] : memref<?xindex>
// CHECK: return %[[B]]
func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
- sparse_tensor.sort %arg0, %arg1 : memref<?xindex>
+ sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 : memref<?xindex>
return %arg1 : memref<?xindex>
}
@@ -518,10 +518,10 @@ func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xi
// CHECK-SAME: %[[B:.*]]: memref<20xindex>,
// CHECK-SAME: %[[C:.*]]: memref<10xindex>,
// CHECK-SAME: %[[D:.*]]: memref<?xf32>)
-// CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
- sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+ sparse_tensor.sort hybrid_quick_sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
}
@@ -532,10 +532,10 @@ func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref
// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
- sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+ sparse_tensor.sort hybrid_quick_sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}
@@ -546,23 +546,34 @@ func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20
// CHECK-SAME: %[[B:.*]]: memref<10xi8>,
// CHECK-SAME: %[[C:.*]]: memref<20xi8>,
// CHECK-SAME: %[[D:.*]]: memref<10xf64>)
-// CHECK: sparse_tensor.sort stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+// CHECK: sparse_tensor.sort insertion_sort_stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
// CHECK: return %[[B]], %[[C]], %[[D]]
func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
- sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+ sparse_tensor.sort insertion_sort_stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
}
// -----
+// CHECK-LABEL: func @sparse_sort_coo(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+// CHECK: sparse_tensor.sort_coo hybrid_quick_sort %[[A]], %[[B]] {nx = 2 : index, ny = 1 : index} : memref<?xindex>
+// CHECK: return %[[B]]
func.func @sparse_sort_coo(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
- sparse_tensor.sort_coo %arg0, %arg1 { nx=2 : index, ny=1 : index}: memref<?xindex>
+ sparse_tensor.sort_coo hybrid_quick_sort %arg0, %arg1 {nx = 2 : index, ny = 1 : index}: memref<?xindex>
return %arg1 : memref<?xindex>
}
// -----
+// CHECK-LABEL: func @sparse_sort_coo_stable(
+// CHECK-SAME: %[[A:.*]]: index,
+// CHECK-SAME: %[[B:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[C:.*]]: memref<?xf32>)
+// CHECK: sparse_tensor.sort_coo insertion_sort_stable %[[A]], %[[B]] jointly %[[C]] {nx = 2 : index, ny = 1 : index}
+// CHECK: return %[[B]], %[[C]]
func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: memref<?xf32>) -> (memref<?xi64>, memref<?xf32>) {
- sparse_tensor.sort_coo stable %arg0, %arg1 jointly %arg2 { nx=2 : index, ny=1 : index}: memref<?xi64> jointly memref<?xf32>
+ sparse_tensor.sort_coo insertion_sort_stable %arg0, %arg1 jointly %arg2 {nx = 2 : index, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
index 90247c33ec529..70a5fa1338ad9 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_codegen.mlir
@@ -125,7 +125,7 @@
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: scf.yield %[[VAL_70:.*]] : index
// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: sparse_tensor.sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
+// CHECK: sparse_tensor.sort hybrid_quick_sort %[[VAL_71:.*]], %[[VAL_39]] : memref<?xindex>
// CHECK: %[[VAL_72:.*]]:4 = scf.for %[[VAL_73:.*]] = %[[VAL_11]] to %[[VAL_71]] step %[[VAL_12]] iter_args(%[[VAL_74:.*]] = %[[VAL_42]], %[[VAL_75:.*]] = %[[VAL_43]], %[[VAL_76:.*]] = %[[VAL_44]], %[[VAL_77:.*]] = %[[VAL_45]]) -> (memref<?xindex>, memref<?xindex>, memref<?xf64>, !sparse_tensor.storage_specifier
// CHECK: %[[VAL_78:.*]] = memref.load %[[VAL_38]]{{\[}}%[[VAL_73]]] : memref<4xindex>
// CHECK: %[[VAL_79:.*]] = memref.load %[[VAL_36]]{{\[}}%[[VAL_78]]] : memref<4xf64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
index 3dc9d5361a22a..19585488dad7d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort.mlir
@@ -50,22 +50,22 @@ module {
// Sort 0 elements.
// CHECK: [10, 2, 0, 5, 1]
- sparse_tensor.sort %i0, %x0 : memref<?xi32>
+ sparse_tensor.sort hybrid_quick_sort %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [10, 2, 0, 5, 1]
- sparse_tensor.sort stable %i0, %x0 : memref<?xi32>
+ sparse_tensor.sort insertion_sort_stable %i0, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Sort the first 4 elements, with the last valid value untouched.
// CHECK: [0, 2, 5, 10, 1]
- sparse_tensor.sort %i4, %x0 : memref<?xi32>
+ sparse_tensor.sort hybrid_quick_sort %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Stable sort.
// CHECK: [0, 2, 5, 10, 1]
call @storeValuesTo(%x0, %c10, %c2, %c0, %c5, %c1)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort stable %i4, %x0 : memref<?xi32>
+ sparse_tensor.sort insertion_sort_stable %i4, %x0 : memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
// Prepare more buffers of different dimensions.
@@ -89,7 +89,7 @@ module {
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort %i5, %x0, %x1, %x2 jointly %y0
+ sparse_tensor.sort hybrid_quick_sort %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
@@ -108,7 +108,7 @@ module {
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y0, %c6, %c10, %c8, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort stable %i5, %x0, %x1, %x2 jointly %y0
+ sparse_tensor.sort insertion_sort_stable %i5, %x0, %x1, %x2 jointly %y0
: memref<?xi32>, memref<?xi32>, memref<?xi32> jointly memref<?xi32>
call @printMemref1dI32(%x0) : (memref<?xi32>) -> ()
call @printMemref1dI32(%x1) : (memref<?xi32>) -> ()
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
index 7c27a076837d8..b0ff0cf19c767 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_sort_coo.mlir
@@ -92,7 +92,7 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ sparse_tensor.sort_coo hybrid_quick_sort %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v : vector<5xi32>
@@ -120,7 +120,7 @@ module {
: (memref<?xi32, strided<[4], offset: ?>>, i32, i32, i32, i32, i32) -> ()
call @storeValuesTo(%y1, %c5, %c7, %c4, %c9, %c7)
: (memref<?xi32>, i32, i32, i32, i32, i32) -> ()
- sparse_tensor.sort_coo stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
+ sparse_tensor.sort_coo insertion_sort_stable %i5, %xy jointly %y1 {nx = 3 : index, ny = 1 : index}
: memref<?xi32> jointly memref<?xi32>
%x0v2 = vector.transfer_read %x0[%i0], %c100: memref<?xi32, strided<[4], offset: ?>>, vector<5xi32>
vector.print %x0v2 : vector<5xi32>
More information about the Mlir-commits
mailing list