[Mlir-commits] [mlir] 414ed01 - [mlir][sparse] Introduce new binary and unary op

Jim Kitchen llvmlistbot at llvm.org
Thu Mar 17 10:31:24 PDT 2022


Author: Jim Kitchen
Date: 2022-03-17T12:31:09-05:00
New Revision: 414ed019acba67e914583fdc51eb3d1328696114

URL: https://github.com/llvm/llvm-project/commit/414ed019acba67e914583fdc51eb3d1328696114
DIFF: https://github.com/llvm/llvm-project/commit/414ed019acba67e914583fdc51eb3d1328696114.diff

LOG: [mlir][sparse] Introduce new binary and unary op

When the sparse_tensor dialect lowers linalg.generic,
it makes inferences about how the operations should
affect the looping logic. For example, multiplication
is an intersection while addition is a union of two
sparse tensors.
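
For instance, when the body of a `linalg.generic` multiplies two sparse operands, the lowering infers intersection semantics and only visits coordinates stored in both inputs. A minimal sketch of such a kernel (the `#SparseVector` encoding and `#trait` indexing attributes are illustrative and assumed to be defined elsewhere):

```mlir
// Element-wise multiplication of two sparse vectors. Because the body is a
// multiplication, the sparse compiler infers an intersection: the generated
// loops only visit coordinates stored in both %A and %B.
%0 = linalg.generic #trait
  ins(%A, %B : tensor<?xf64, #SparseVector>, tensor<?xf64, #SparseVector>)
  outs(%C : tensor<?xf64, #SparseVector>) {
  ^bb0(%a: f64, %b: f64, %c: f64):
    %m = arith.mulf %a, %b : f64
    linalg.yield %m : f64
} -> tensor<?xf64, #SparseVector>
```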

The new binary and unary op separate the looping logic
from the computation by nesting the computation code
inside a block which is merged at the appropriate level
in the lowered looping code.

The binary op can have custom computation code for the
overlap, left, and right sparse overlap regions. The
unary op can have custom computation code for the
present and absent values.
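
To make this concrete, here is a hedged sketch of the new ops as they would appear inside a `linalg.generic` body (operand names and types are illustrative): the binary op states the intersection explicitly, and the unary op distinguishes stored from missing values.

```mlir
// Explicit intersection: multiply only where both inputs store a value;
// empty left/right regions mean non-overlapping entries produce no output.
%p = sparse_tensor.binary %a, %b : f64, f64 to f64
  overlap={
    ^bb0(%x: f64, %y: f64):
      %m = arith.mulf %x, %y : f64
      sparse_tensor.yield %m : f64
  }
  left={}
  right={}

// Explicit handling of stored vs. missing values: negate stored entries,
// leave missing entries missing.
%n = sparse_tensor.unary %a : f64 to f64
  present={
    ^bb0(%x: f64):
      %r = arith.negf %x : f64
      sparse_tensor.yield %r : f64
  }
  absent={}
```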

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D121018

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0000714fc50cf..88b87b824c97b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -379,4 +379,216 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Custom Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [NoSideEffect]>,
+    Arguments<(ins AnyType:$x, AnyType:$y, UnitAttr:$left_identity, UnitAttr:$right_identity)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Binary set operation utilized within linalg.generic";
+  let description = [{
+      Defines a computation within a `linalg.generic` operation that takes two
+      operands and executes one of its three regions depending on whether both
+      operands are nonzero, only the left operand is nonzero, or only the right
+      operand is nonzero (i.e. stored explicitly in the sparse storage format).
+
+      Three regions are defined for the operation and must appear in this order:
+      - overlap (elements present in both sparse tensors)
+      - left (elements only present in the left sparse tensor)
+      - right (elements only present in the right sparse tensor)
+
+      Each region contains a single block describing the computation and result.
+      Every non-empty block must end with a sparse_tensor.yield and the return
+      type must match the type of `output`. The overlap region's block has two
+      arguments; the left and right regions' blocks each have one argument.
+
+      A region may also be declared empty (i.e. `left={}`), indicating that the
+      region does not contribute to the output. For example, setting both
+      `left={}` and `right={}` restricts the operation to the intersection of
+      the two inputs, as only the overlap region contributes values to the output.
+
+      As a convenience, there is also a special token `identity` which can be
+      used in place of the left or right region. This token indicates that
+      the return value is the input value (i.e. func(%x) => return %x).
+      As a practical example, setting `left=identity` and `right=identity`
+      would be equivalent to a union operation where non-overlapping values
+      in the inputs are copied to the output unchanged.
+
+      Example of isEqual applied to intersecting elements only:
+      ```mlir
+      %C = sparse_tensor.init...
+      %0 = linalg.generic #trait
+        ins(%A: tensor<?xf64, #SparseVec>, %B: tensor<?xf64, #SparseVec>)
+        outs(%C: tensor<?xi8, #SparseVec>) {
+        ^bb0(%a: f64, %b: f64, %c: i8) :
+          %result = sparse_tensor.binary %a, %b : f64, f64 to i8
+            overlap={
+              ^bb0(%arg0: f64, %arg1: f64):
+                %cmp = arith.cmpf "oeq", %arg0, %arg1 : f64
+                %ret_i8 = arith.extui %cmp : i1 to i8
+                sparse_tensor.yield %ret_i8 : i8
+            }
+            left={}
+            right={}
+          linalg.yield %result : i8
+      } -> tensor<?xi8, #SparseVec>
+      ```
+
+      Example of A+B in upper triangle, A-B in lower triangle:
+      ```mlir
+      %C = sparse_tensor.init...
+      %1 = linalg.generic #trait
+        ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64, #CSR>)
+        outs(%C: tensor<?x?xf64, #CSR>) {
+        ^bb0(%a: f64, %b: f64, %c: f64) :
+          %row = linalg.index 0 : index
+          %col = linalg.index 1 : index
+          %result = sparse_tensor.binary %a, %b : f64, f64 to f64
+            overlap={
+              ^bb0(%x: f64, %y: f64):
+                %cmp = arith.cmpi "uge", %col, %row : index
+                %upperTriangleResult = arith.addf %x, %y : f64
+                %lowerTriangleResult = arith.subf %x, %y : f64
+                %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64
+                sparse_tensor.yield %ret : f64
+            }
+            left=identity
+            right={
+              ^bb0(%y: f64):
+                %cmp = arith.cmpi "uge", %col, %row : index
+                %lowerTriangleResult = arith.negf %y : f64
+                %ret = arith.select %cmp, %y, %lowerTriangleResult : f64
+                sparse_tensor.yield %ret : f64
+            }
+          linalg.yield %result : f64
+      } -> tensor<?x?xf64, #CSR>
+      ```
+
+      Example of set difference. Returns a copy of A where its sparse structure
+      is *not* overlapped by B. The element type of B can be different than A
+      because we never use its values, only its sparse structure.
+      ```mlir
+      %C = sparse_tensor.init...
+      %2 = linalg.generic #trait
+        ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xi32, #CSR>)
+        outs(%C: tensor<?x?xf64, #CSR>) {
+        ^bb0(%a: f64, %b: i32, %c: f64) :
+          %result = sparse_tensor.binary %a, %b : f64, i32 to f64
+            overlap={}
+            left=identity
+            right={}
+          linalg.yield %result : f64
+      } -> tensor<?x?xf64, #CSR>
+      ```
+  }];
+
+  let regions = (region AnyRegion:$overlapRegion, AnyRegion:$leftRegion, AnyRegion:$rightRegion);
+  let assemblyFormat = [{
+        $x `,` $y `:` attr-dict type($x) `,` type($y) `to` type($output) `\n`
+        `overlap` `=` $overlapRegion `\n`
+        `left` `=` (`identity` $left_identity^):($leftRegion)? `\n`
+        `right` `=` (`identity` $right_identity^):($rightRegion)?
+  }];
+  let hasVerifier = 1;
+}
+
+def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [NoSideEffect]>,
+    Arguments<(ins AnyType:$x)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Unary set operation utilized within linalg.generic";
+  let description = [{
+      Defines a computation within a `linalg.generic` operation that takes a single
+      operand and executes one of two regions depending on whether the operand is
+      nonzero (i.e. stored explicitly in the sparse storage format).
+
+      Two regions are defined for the operation and must appear in this order:
+      - present (elements present in the sparse tensor)
+      - absent (elements not present in the sparse tensor)
+
+      Each region contains a single block describing the computation and result.
+      Every non-empty block must end with a sparse_tensor.yield and the return
+      type must match the type of `output`. The present region's block has one
+      argument, while the absent region's block has no arguments.
+
+      A region may also be declared empty (i.e. `absent={}`), indicating that the
+      region does not contribute to the output.
+
+      Example of A+1, restricted to existing elements:
+      ```mlir
+      %C = sparse_tensor.init...
+      %0 = linalg.generic #trait
+        ins(%A: tensor<?xf64, #SparseVec>)
+        outs(%C: tensor<?xf64, #SparseVec>) {
+        ^bb0(%a: f64, %c: f64) :
+          %result = sparse_tensor.unary %a : f64 to f64
+            present={
+              ^bb0(%arg0: f64):
+                %cf1 = arith.constant 1.0 : f64
+                %ret = arith.addf %arg0, %cf1 : f64
+                sparse_tensor.yield %ret : f64
+            }
+            absent={}
+          linalg.yield %result : f64
+      } -> tensor<?xf64, #SparseVec>
+      ```
+
+      Example returning +1 for existing values and -1 for missing values:
+      ```mlir
+      %result = sparse_tensor.unary %a : f64 to i32
+        present={
+          ^bb0(%x: f64):
+            %ret = arith.constant 1 : i32
+            sparse_tensor.yield %ret : i32
+        }
+        absent={
+          %ret = arith.constant -1 : i32
+          sparse_tensor.yield %ret : i32
+        }
+      ```
+
+      Example showing a structural inversion (existing values become missing in
+      the output, while missing values are filled with 1):
+      ```mlir
+      %result = sparse_tensor.unary %a : f64 to i64
+        present={}
+        absent={
+          %ret = arith.constant 1 : i64
+          sparse_tensor.yield %ret : i64
+        }
+      ```
+  }];
+
+  let regions = (region AnyRegion:$presentRegion, AnyRegion:$absentRegion);
+  let assemblyFormat = [{
+        $x attr-dict `:` type($x) `to` type($output) `\n`
+        `present` `=` $presentRegion `\n`
+        `absent` `=` $absentRegion
+  }];
+  let hasVerifier = 1;
+}
+
+def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
+    Arguments<(ins AnyType:$result)> {
+  let summary = "Yield from sparse_tensor set-like operations";
+  let description = [{
+      Yields a value from within a `binary` or `unary` block.
+
+      Example:
+      ```mlir
+      %0 = sparse_tensor.unary %a : i64 to i64
+        present={
+          ^bb0(%arg0: i64):
+            %cst = arith.constant 1 : i64
+            %ret = arith.addi %arg0, %cst : i64
+            sparse_tensor.yield %ret : i64
+        }
+        absent={}
+      ```
+  }];
+
+  let assemblyFormat = [{
+        $result attr-dict `:` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // SPARSETENSOR_OPS

diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 91e23d9076d81..0804280b231b9 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -333,6 +333,115 @@ LogicalResult OutOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// TensorDialect Linalg.Generic Operations.
+//===----------------------------------------------------------------------===//
+
+template <class T>
+static LogicalResult verifyNumBlockArgs(T *op, Region &region,
+                                        const char *regionName,
+                                        TypeRange inputTypes, Type outputType) {
+  unsigned numArgs = region.getNumArguments();
+  unsigned expectedNum = inputTypes.size();
+  if (numArgs != expectedNum)
+    return op->emitError() << regionName << " region must have exactly "
+                           << expectedNum << " arguments";
+
+  for (unsigned i = 0; i < numArgs; i++) {
+    Type typ = region.getArgument(i).getType();
+    if (typ != inputTypes[i])
+      return op->emitError() << regionName << " region argument " << (i + 1)
+                             << " type mismatch";
+  }
+  Operation *term = region.front().getTerminator();
+  YieldOp yield = dyn_cast<YieldOp>(term);
+  if (!yield)
+    return op->emitError() << regionName
+                           << " region must end with sparse_tensor.yield";
+  if (yield.getOperand().getType() != outputType)
+    return op->emitError() << regionName << " region yield type mismatch";
+
+  return success();
+}
+
+LogicalResult BinaryOp::verify() {
+  Type leftType = x().getType();
+  Type rightType = y().getType();
+  Type outputType = output().getType();
+  Region &overlap = overlapRegion();
+  Region &left = leftRegion();
+  Region &right = rightRegion();
+
+  // Check correct number of block arguments and return type for each
+  // non-empty region.
+  LogicalResult regionResult = success();
+  if (!overlap.empty()) {
+    regionResult = verifyNumBlockArgs(
+        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
+    if (failed(regionResult))
+      return regionResult;
+  }
+  if (!left.empty()) {
+    regionResult =
+        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
+    if (failed(regionResult))
+      return regionResult;
+  } else if (left_identity()) {
+    if (leftType != outputType)
+      return emitError("left=identity requires first argument to have the same "
+                       "type as the output");
+  }
+  if (!right.empty()) {
+    regionResult = verifyNumBlockArgs(this, right, "right",
+                                      TypeRange{rightType}, outputType);
+    if (failed(regionResult))
+      return regionResult;
+  } else if (right_identity()) {
+    if (rightType != outputType)
+      return emitError("right=identity requires second argument to have the "
+                       "same type as the output");
+  }
+
+  return success();
+}
+
+LogicalResult UnaryOp::verify() {
+  Type inputType = x().getType();
+  Type outputType = output().getType();
+  LogicalResult regionResult = success();
+
+  // Check correct number of block arguments and return type for each
+  // non-empty region.
+  Region &present = presentRegion();
+  if (!present.empty()) {
+    regionResult = verifyNumBlockArgs(this, present, "present",
+                                      TypeRange{inputType}, outputType);
+    if (failed(regionResult))
+      return regionResult;
+  }
+  Region &absent = absentRegion();
+  if (!absent.empty()) {
+    regionResult =
+        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
+    if (failed(regionResult))
+      return regionResult;
+  }
+
+  return success();
+}
+
+LogicalResult YieldOp::verify() {
+  // Check for compatible parent.
+  auto *parentOp = (*this)->getParentOp();
+  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp))
+    return success();
+
+  return emitOpError("expected parent op to be sparse_tensor binary or unary");
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 84990221e4df4..a6fdb37262ae2 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -212,3 +212,111 @@ func @invalid_out_dense(%arg0: tensor<10xf64>, %arg1: !llvm.ptr<i8>) {
   sparse_tensor.out %arg0, %arg1 : tensor<10xf64>, !llvm.ptr<i8>
   return
 }
+
+// -----
+
+func @invalid_binary_num_args_mismatch_overlap(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error at +1 {{overlap region must have exactly 2 arguments}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    }
+    left={}
+    right={}
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_num_args_mismatch_right(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error at +1 {{right region must have exactly 1 arguments}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={}
+    left={}
+    right={
+      ^bb0(%x: f64, %y: f64):
+        sparse_tensor.yield %y : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_argtype_mismatch(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error at +1 {{overlap region argument 2 type mismatch}}
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={
+      ^bb0(%x: f64, %y: f32):
+        sparse_tensor.yield %x : f64
+    }
+    left=identity
+    right=identity
+  return %r : f64
+}
+
+// -----
+
+func @invalid_binary_wrong_return_type(%arg0: f64, %arg1: f64) -> f64 {
+  // expected-error at +1 {{left region yield type mismatch}}
+  %0 = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
+    overlap={}
+    left={
+      ^bb0(%x: f64):
+        %1 = arith.constant 0.0 : f32
+        sparse_tensor.yield %1 : f32
+    }
+    right=identity
+  return %0 : f64
+}
+
+// -----
+
+func @invalid_binary_wrong_identity_type(%arg0: i64, %arg1: f64) -> f64 {
+  // expected-error at +1 {{left=identity requires first argument to have the same type as the output}}
+  %0 = sparse_tensor.binary %arg0, %arg1 : i64, f64 to f64
+    overlap={}
+    left=identity
+    right=identity
+  return %0 : f64
+}
+
+// -----
+
+func @invalid_unary_argtype_mismatch(%arg0: f64) -> f64 {
+  // expected-error at +1 {{present region argument 1 type mismatch}}
+  %r = sparse_tensor.unary %arg0 : f64 to f64
+    present={
+      ^bb0(%x: index):
+        sparse_tensor.yield %x : index
+    }
+    absent={}
+  return %r : f64
+}
+
+// -----
+
+func @invalid_unary_num_args_mismatch(%arg0: f64) -> f64 {
+  // expected-error at +1 {{absent region must have exactly 0 arguments}}
+  %r = sparse_tensor.unary %arg0 : f64 to f64
+    present={}
+    absent={
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func @invalid_unary_wrong_return_type(%arg0: f64) -> f64 {
+  // expected-error at +1 {{present region yield type mismatch}}
+  %0 = sparse_tensor.unary %arg0 : f64 to f64
+    present={
+      ^bb0(%x: f64):
+        %1 = arith.constant 0.0 : f32
+        sparse_tensor.yield %1 : f32
+    }
+    absent={}
+  return %0 : f64
+}

diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 5457e55f57e6a..a7ea965410ab3 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -193,3 +193,94 @@ func @sparse_out(%arg0: tensor<?x?xf64, #SparseMatrix>, %arg1: !llvm.ptr<i8>) {
   sparse_tensor.out %arg0, %arg1 : tensor<?x?xf64, #SparseMatrix>, !llvm.ptr<i8>
   return
 }
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_binary(
+//  CHECK-SAME:   %[[A:.*]]: f64, %[[B:.*]]: i64) -> f64 {
+//       CHECK:   %[[Z:.*]] = arith.constant 0.000000e+00 : f64
+//       CHECK:   %[[C1:.*]] = sparse_tensor.binary %[[A]], %[[B]] : f64, i64 to f64
+//       CHECK:     overlap = {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64, %[[B1:.*]]: i64):
+//       CHECK:         sparse_tensor.yield %[[A1]] : f64
+//       CHECK:     }
+//       CHECK:     left = identity
+//       CHECK:     right = {
+//       CHECK:       ^bb0(%[[A2:.*]]: i64):
+//       CHECK:         sparse_tensor.yield %[[Z]] : f64
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func @sparse_binary(%arg0: f64, %arg1: i64) -> f64 {
+  %cf0 = arith.constant 0.0 : f64
+  %r = sparse_tensor.binary %arg0, %arg1 : f64, i64 to f64
+    overlap={
+      ^bb0(%x: f64, %y: i64):
+        sparse_tensor.yield %x : f64
+    }
+    left=identity
+    right={
+      ^bb0(%y: i64):
+        sparse_tensor.yield %cf0 : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_unary(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> f64 {
+//       CHECK:   %[[C1:.*]] = sparse_tensor.unary %[[A]] : f64 to f64
+//       CHECK:     present = {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64):
+//       CHECK:         sparse_tensor.yield %[[A1]] : f64
+//       CHECK:     }
+//       CHECK:     absent = {
+//       CHECK:       %[[R:.*]] = arith.constant -1.000000e+00 : f64
+//       CHECK:       sparse_tensor.yield %[[R]] : f64
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func @sparse_unary(%arg0: f64) -> f64 {
+  %r = sparse_tensor.unary %arg0 : f64 to f64
+    present={
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    } absent={
+      ^bb0:
+        %cf1 = arith.constant -1.0 : f64
+        sparse_tensor.yield %cf1 : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_unary(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> i64 {
+//       CHECK:   %[[C1:.*]] = sparse_tensor.unary %[[A]] : f64 to i64
+//       CHECK:     present = {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64):
+//       CHECK:         %[[R:.*]] = arith.fptosi %[[A1]] : f64 to i64
+//       CHECK:         sparse_tensor.yield %[[R]] : i64
+//       CHECK:     }
+//       CHECK:     absent = {
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : i64
+//       CHECK: }
+func @sparse_unary(%arg0: f64) -> i64 {
+  %r = sparse_tensor.unary %arg0 : f64 to i64
+    present={
+      ^bb0(%x: f64):
+        %ret = arith.fptosi %x : f64 to i64
+        sparse_tensor.yield %ret : i64
+    }
+    absent={}
+  return %r : i64
+}


        

