[Mlir-commits] [mlir] 4329ca6 - [mlir][sparse] Add sparse_tensor.sort operator.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 27 12:00:17 PDT 2022


Author: bixia1
Date: 2022-09-27T12:00:08-07:00
New Revision: 4329ca61e821bb594ffaa280481b7f92d8ad31b1

URL: https://github.com/llvm/llvm-project/commit/4329ca61e821bb594ffaa280481b7f92d8ad31b1
DIFF: https://github.com/llvm/llvm-project/commit/4329ca61e821bb594ffaa280481b7f92d8ad31b1.diff

LOG: [mlir][sparse] Add sparse_tensor.sort operator.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134482

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index ef1edf1fdf4f6..2d2a0499a7b9b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -392,6 +392,51 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
 }
 
+def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
+    // TODO: May want to extend tablegen with
+    // class NonemptyVariadic<Type type> : Variadic<type> { let minSize = 1; }
+    // and then use NonemptyVariadic<...>:$xs here.
+    Arguments<(ins Index:$n,
+               Variadic<StridedMemRefRankOf<[AnyInteger, Index], [1]>>:$xs,
+               Variadic<StridedMemRefRankOf<[AnyType], [1]>>:$ys)>  {
+  string summary = "Sorts the arrays in xs and ys lexicographically on the "
+                   "integral values found in the xs list";
+  string description = [{
+    Lexicographically sort the first `n` values in `xs` along with the values in
+    `ys`. Conceptually, the values being sorted are tuples produced by
+    `zip(zip(xs), zip(ys))`. In particular, values in `ys` need to be sorted
+    along with values in `xs`, but values in `ys` don't affect the
+    lexicographical order. The order in which arrays appear in `xs` affects the
+    sorting result. The operator updates `xs` and `ys` in place with the result
+    of the sorting.
+
+    For example, assume x1=[4, 3], x2=[1, 2], y1=[10, 5], then the output of
+    "sort 2, x1, x2 jointly y1" is x1=[3, 4], x2=[2, 1], y1=[5, 10] while the
+    output of "sort 2, x2, x1 jointly y1" is x2=[1, 2], x1=[4, 3], y1=[10, 5].
+
+    Buffers in `xs` need to have the same integral element type while buffers
+    in `ys` can have different numeric element types. All buffers in `xs` and
+    `ys` should have a dimension not less than `n`. The behavior of the operator
+    is undefined if this condition is not met. The operator requires at least
+    one buffer in `xs` while `ys` can be empty.
+
+    Note that this operation is "impure" in the sense that its behavior is
+    solely defined by side-effects and not SSA values. The semantics may be
+    refined over time as our sparse abstractions evolve.
+
+    Example:
+
+    ```mlir
+    sparse_tensor.sort %n, %x1, %x2 jointly %y1, %y2
+      : memref<?xindex>, memref<?xindex> jointly memref<?xindex>, memref<?xf32>
+    ```
+  }];
+  let assemblyFormat = "$n `,` $xs (`jointly` $ys^)? attr-dict"
+                       "`:` type($xs) (`jointly` type($ys)^)?";
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Syntax Operations.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index 5172031909b46..9dba53018015c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
   MLIRSparseTensorOpsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArithmeticDialect
   MLIRDialect
   MLIRIR
   MLIRInferTypeOpInterface

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 2e98eaa7561c7..ba344e3725199 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Matchers.h"
@@ -505,6 +506,41 @@ LogicalResult SelectOp::verify() {
   return success();
 }
 
+LogicalResult SortOp::verify() { // Verifier for the operands of sparse_tensor.sort.
+  if (getXs().empty()) // At least one xs buffer is required.
+    return emitError("need at least one xs buffer.");
+
+  auto n = getN().getDefiningOp<arith::ConstantIndexOp>(); // Null-tested below; the size check is skipped when null.
+
+  Type xtp = getXs().front().getType().cast<MemRefType>().getElementType(); // Element type of the first xs buffer, used as the reference.
+  auto checkTypes = [&](ValueRange operands,
+                        bool checkEleType = true) -> LogicalResult {
+    for (Value opnd : operands) {
+      MemRefType mtp = opnd.getType().cast<MemRefType>();
+      uint64_t dim = mtp.getShape()[0]; // Size of dimension 0, the only dimension inspected here.
+      // We can't check the size of dynamic dimension at compile-time, but all
+      // xs and ys should have a dimension not less than n at runtime.
+      if (n && dim != ShapedType::kDynamicSize && dim < n.value())
+        return emitError(llvm::formatv("xs and ys need to have a dimension >= n"
+                                       ": {0} < {1}",
+                                       dim, n.value()));
+
+      if (checkEleType && xtp != mtp.getElementType()) // Compared against the first xs element type.
+        return emitError("mismatch xs element types");
+    }
+    return success();
+  };
+
+  LogicalResult result = checkTypes(getXs()); // xs: size check plus element-type check.
+  if (failed(result))
+    return result;
+
+  if (n) // With a null n, checkTypes(..., false) can never fail, so ys is skipped.
+    return checkTypes(getYs(), false);
+
+  return success();
+}
+
 LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();

diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index af913204fabba..7425d19efb3b8 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -501,3 +501,29 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
   }
   return
 }
+
+// -----
+
+// TODO: a test case with empty xs doesn't work due to some parser issues.
+
+func.func @sparse_sort_x_type( %arg0: index, %arg1: memref<?xf32>) {
+  // expected-error@+1 {{operand #1 must be 1D memref of integer or index values}}
+  sparse_tensor.sort %arg0, %arg1: memref<?xf32>
+}
+
+// -----
+
+func.func @sparse_sort_dim_too_small(%arg0: memref<10xindex>) {
+  %i20 = arith.constant 20 : index
+  // expected-error@+1 {{xs and ys need to have a dimension >= n: 10 < 20}}
+  sparse_tensor.sort %i20, %arg0 : memref<10xindex>
+  return
+}
+
+// -----
+
+func.func @sparse_sort_mismatch_x_type(%arg0: index, %arg1: memref<10xindex>, %arg2: memref<10xi8>) {
+  // expected-error@+1 {{mismatch xs element types}}
+  sparse_tensor.sort %arg0, %arg1, %arg2 : memref<10xindex>, memref<10xi8>
+  return
+}

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index fd4b508ad4852..7059b15905436 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -362,3 +362,43 @@ func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () {
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1d0v(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<?xindex>)
+//       CHECK: sparse_tensor.sort %[[A]], %[[B]] : memref<?xindex>
+//       CHECK: return %[[B]]
+func.func @sparse_sort_1d0v(%arg0: index, %arg1: memref<?xindex>) -> (memref<?xindex>) {
+  sparse_tensor.sort %arg0, %arg1 : memref<?xindex>
+  return %arg1 : memref<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_1d2v(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<20xindex>,
+//  CHECK-SAME: %[[C:.*]]: memref<10xindex>,
+//  CHECK-SAME: %[[D:.*]]: memref<?xf32>)
+//       CHECK: sparse_tensor.sort %[[A]], %[[B]] jointly %[[C]], %[[D]] : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+//       CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_1d2v(%arg0: index, %arg1: memref<20xindex>, %arg2: memref<10xindex>, %arg3: memref<?xf32>) -> (memref<20xindex>, memref<10xindex>, memref<?xf32>) {
+  sparse_tensor.sort %arg0, %arg1 jointly %arg2, %arg3 : memref<20xindex> jointly memref<10xindex>, memref<?xf32>
+  return %arg1, %arg2, %arg3 : memref<20xindex>, memref<10xindex>, memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_sort_2d1v(
+//  CHECK-SAME: %[[A:.*]]: index,
+//  CHECK-SAME: %[[B:.*]]: memref<10xi8>,
+//  CHECK-SAME: %[[C:.*]]: memref<20xi8>,
+//  CHECK-SAME: %[[D:.*]]: memref<10xf64>)
+//       CHECK: sparse_tensor.sort %[[A]], %[[B]], %[[C]] jointly %[[D]] : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+//       CHECK: return %[[B]], %[[C]], %[[D]]
+func.func @sparse_sort_2d1v(%arg0: index, %arg1: memref<10xi8>, %arg2: memref<20xi8>, %arg3: memref<10xf64>) -> (memref<10xi8>, memref<20xi8>, memref<10xf64>) {
+  sparse_tensor.sort %arg0, %arg1, %arg2 jointly %arg3 : memref<10xi8>, memref<20xi8> jointly memref<10xf64>
+  return %arg1, %arg2, %arg3 : memref<10xi8>, memref<20xi8>, memref<10xf64>
+}

diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 135e54ce2973e..ccc21652b3838 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -2030,6 +2030,7 @@ cc_library(
     hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
     includes = ["include"],
     deps = [
+        ":ArithmeticDialect",
         ":IR",
         ":InferTypeOpInterface",
         ":SparseTensorAttrDefsIncGen",


        


More information about the Mlir-commits mailing list