[Mlir-commits] [mlir] 8c02ca1 - [mlir][sparse] Add an attribute to the sort operator for stable sorting.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 4 15:14:13 PDT 2022


Author: bixia1
Date: 2022-10-04T15:14:03-07:00
New Revision: 8c02ca1da5bdcb7f7e850afb24d95bb6d82d8971

URL: https://github.com/llvm/llvm-project/commit/8c02ca1da5bdcb7f7e850afb24d95bb6d82d8971
DIFF: https://github.com/llvm/llvm-project/commit/8c02ca1da5bdcb7f7e850afb24d95bb6d82d8971.diff

LOG: [mlir][sparse] Add an attribute to the sort operator for stable sorting.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D135181

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 8cd1a01a2e330..4d1f23719ee08 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -415,7 +415,8 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
     // and then use NonemptyVariadic<...>:$xs here.
     Arguments<(ins Index:$n,
                Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
-               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys)>  {
+               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys,
+               UnitAttr:$stable)>  {
   string summary = "Sorts the arrays in xs and ys lexicographically on the "
                    "integral values found in the xs list";
   string description = [{
@@ -437,6 +438,9 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
     is undefined if this condition is not met. The operator requires at least
     one buffer in `xs` while `ys` can be empty.
 
+    When the `stable` unit attribute is present, the operator must be implemented
+    with a stable sorting algorithm (equal keys keep their relative order).
+
     Note that this operation is "impure" in the sense that its behavior is
     solely defined by side-effects and not SSA values. The semantics may be
     refined over time as our sparse abstractions evolve.
@@ -447,10 +451,18 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
     sparse_tensor.sort %n, %x1, %x2 jointly y1, %y2
       : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
     ```
+
+    ```mlir
+    sparse_tensor.sort stable %n, %x1, %x2 jointly %y1, %y2
+      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
+    ```
   }];
-  let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict"
+  let assemblyFormat = "(`stable` $stable^)? $n"
+                       "`,`$xs (`jointly` $ys^)? attr-dict"
                        "`:` type($xs) (`jointly` type($ys)^)?";
-
+  let builders = [
+        OpBuilder<(ins "Value":$n, "ValueRange":$xs, "ValueRange":$ys)>
+  ];
   let hasVerifier = 1;
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index d12ecb9d023f6..9f12a5481e406 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -675,6 +675,11 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
+void SortOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value n,
+                   ValueRange xs, ValueRange ys) {
+  build(odsBuilder, odsState, n, xs, ys, /*stable=*/false);
+}
+
 LogicalResult SortOp::verify() {
   if (getXs().empty())
     return emitError("need at least one xs buffer.");

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index ee69af3d256a2..fd850aacacae7 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -423,3 +423,17 @@ func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20
   sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
   return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_stable(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<10xi8>,
+//  CHECK-SAME: %[[C:.*]]: memref<20xi8>,
+//  CHECK-SAME: %[[D:.*]]: memref<10xf64>)
+//       CHECK: sparse_tensor.sort stable %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+//       CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_stable(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
+  sparse_tensor.sort stable %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
+}


        


More information about the Mlir-commits mailing list