[Mlir-commits] [mlir] 07150fe - [mlir][sparse] Add sparse_tensor.select operation

Jim Kitchen llvmlistbot at llvm.org
Tue Sep 13 13:29:22 PDT 2022

Author: Jim Kitchen
Date: 2022-09-13T15:22:53-05:00
New Revision: 07150fece507d72bd35619b51e5cfd17ed2474ca

URL: https://github.com/llvm/llvm-project/commit/07150fece507d72bd35619b51e5cfd17ed2474ca
DIFF: https://github.com/llvm/llvm-project/commit/07150fece507d72bd35619b51e5cfd17ed2474ca.diff

LOG: [mlir][sparse] Add sparse_tensor.select operation

The new select operation allows filtering of sparse tensors
by conditionally keeping or removing each element. This
can be used to remove negative values or select the upper
triangle of a matrix.

The select op has a single region that operates on a single
value and must yield a boolean (i1): true keeps the element,
false drops it.

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D133569




diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 28401dae69ca3..ed1943f20de2e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -604,11 +604,72 @@ def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [NoSideEffect, SameOperand
   let hasVerifier = 1;
+def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$x)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Select operation utilized within linalg.generic";
+  let description = [{
+      Defines an evaluation within a `linalg.generic` operation that takes a single
+      operand and decides whether or not to keep that operand in the output.
+      A single region must contain exactly one block taking one argument. The block
+      must end with a sparse_tensor.yield whose operand is a boolean (i1); the
+      op's result has the same type as its input. Value thresholding is an
+      obvious usage of the select operation. However, by using `linalg.index`,
+      other useful selections can be achieved, such as selecting the upper or
+      lower triangle of a matrix.
+      Example of selecting A >= 4.0:
+      ```mlir
+      %C = bufferization.alloc_tensor...
+      %0 = linalg.generic #trait
+         ins(%A: tensor<?xf64, #SparseVector>)
+        outs(%C: tensor<?xf64, #SparseVector>) {
+        ^bb0(%a: f64, %c: f64) :
+          %result = sparse_tensor.select %a : f64 {
+              ^bb0(%arg0: f64):
+                %cf4 = arith.constant 4.0 : f64
+                %keep = arith.cmpf "oge", %arg0, %cf4 : f64
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %result : f64
+      } -> tensor<?xf64, #SparseVector>
+      ```
+      Example of selecting lower triangle of a matrix:
+      ```mlir
+      %C = bufferization.alloc_tensor...
+      %0 = linalg.generic #trait
+         ins(%A: tensor<?x?xf64, #CSR>)
+        outs(%C: tensor<?x?xf64, #CSR>) {
+        ^bb0(%a: f64, %c: f64) :
+          %row = linalg.index 0 : index
+          %col = linalg.index 1 : index
+          %result = sparse_tensor.select %a : f64 {
+              ^bb0(%arg0: f64):
+                %keep = arith.cmpi "ult", %col, %row : index
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %result : f64
+      } -> tensor<?x?xf64, #CSR>
+      ```
+  }];
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = [{
+         $x attr-dict `:` type($x) $region
+  }];
+  let hasVerifier = 1;
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
     Arguments<(ins AnyType:$result)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
-      Yields a value from within a `binary` or `unary` block.
+      Yields a value from within a `binary`, `unary`, `reduce`,
+      or `select` block.

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 3c553640624e8..c647b0bd0db7c 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -458,12 +458,27 @@ LogicalResult ReduceOp::verify() {
   // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  if (!formula.empty()) {
-    regionResult = verifyNumBlockArgs(
-        this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
-    if (failed(regionResult))
-      return regionResult;
-  }
+  regionResult = verifyNumBlockArgs(this, formula, "reduce",
+                                    TypeRange{inputType, inputType}, inputType);
+  if (failed(regionResult))
+    return regionResult;
+  return success();
+LogicalResult SelectOp::verify() {
+  Builder b(getContext());
+  Type inputType = getX().getType();
+  Type boolType = b.getI1Type();
+  LogicalResult regionResult = success();
+  // Check correct number of block arguments and return type.
+  Region &formula = getRegion();
+  regionResult = verifyNumBlockArgs(this, formula, "select",
+                                    TypeRange{inputType}, boolType);
+  if (failed(regionResult))
+    return regionResult;
   return success();
@@ -472,11 +487,11 @@ LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp))
+      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
     return success();
-  return emitOpError(
-      "expected parent op to be sparse_tensor unary, binary, or reduce");
+  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
+                     "reduce, or select");

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index c8fd0ddc1357d..c607dd2e77fee 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -355,6 +355,40 @@ func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
 // -----
+func.func @invalid_select_num_args_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region must have exactly 1 arguments}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64, %y: f64):
+        %ret = arith.constant 1 : i1
+        sparse_tensor.yield %ret : i1
+    }
+  return %r : f64
+// -----
+func.func @invalid_select_return_type_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region yield type mismatch}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    }
+  return %r : f64
+// -----
+func.func @invalid_select_wrong_yield(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region must end with sparse_tensor.yield}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        tensor.yield %x : f64
+    }
+  return %r : f64
+// -----
 #DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
 func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
   // expected-error@+1 {{Need at least two tensors to concatenate.}}

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index b795f542d5adf..5c22ffb0cc69c 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -291,6 +291,30 @@ func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 {
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+// CHECK-LABEL: func @sparse_select(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> f64 {
+//       CHECK:   %[[Z:.*]] = arith.constant 0.000000e+00 : f64
+//       CHECK:   %[[C1:.*]] = sparse_tensor.select %[[A]] : f64 {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64):
+//       CHECK:         %[[B1:.*]] = arith.cmpf ogt, %[[A1]], %[[Z]] : f64
+//       CHECK:         sparse_tensor.yield %[[B1]] : i1
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func.func @sparse_select(%arg0: f64) -> f64 {
+  %cf0 = arith.constant 0.0 : f64
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        %cmp = arith.cmpf "ogt", %x, %cf0 : f64
+        sparse_tensor.yield %cmp : i1
+    }
+  return %r : f64
+// -----
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 // CHECK-LABEL: func @concat_sparse_sparse(
 //  CHECK-SAME:   %[[A0:.*]]: tensor<2x4xf64
 //  CHECK-SAME:   %[[A1:.*]]: tensor<3x4xf64


More information about the Mlir-commits mailing list