[Mlir-commits] [mlir] de90713 - [mlir][sparse] Add new concatenate operator to sparse tensor

Peiming Liu llvmlistbot at llvm.org
Mon Aug 8 10:23:49 PDT 2022


Author: Peiming Liu
Date: 2022-08-08T17:23:43Z
New Revision: de907138ec96de063660b91a8adc7f28aa1bea98

URL: https://github.com/llvm/llvm-project/commit/de907138ec96de063660b91a8adc7f28aa1bea98
DIFF: https://github.com/llvm/llvm-project/commit/de907138ec96de063660b91a8adc7f28aa1bea98.diff

LOG: [mlir][sparse] Add new concatenate operator to sparse tensor

See https://www.tensorflow.org/xla/operation_semantics#concatenate for the operator semantics

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D131111

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/test/Dialect/SparseTensor/invalid.mlir
    mlir/test/Dialect/SparseTensor/roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9f99b52291177..db0fa64756897 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -163,6 +163,34 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
   let hasVerifier = 1;
 }
 
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", []>,
+    Arguments<(ins Variadic<AnyRankedTensor>:$inputs,
+                   IndexAttr:$dimension)>,
+    Results<(outs AnyRankedTensor:$result)> {
+
+  let summary = "Concatenates a list of tensors into a single tensor.";
+  let description = [{
+     Concatenates a list of input tensors into an output tensor of the same
+     rank. The concatenation happens along the specified `dimension`
+     (0 <= dimension < rank). The size of the output `dimension` is the sum
+     of the input sizes along that dimension, while all other dimensions must
+     have the same size in the input and output tensors.
+
+     At least two input tensors are required and they must be statically
+     sized, while the output tensor can be dynamically-sized.
+
+     Example:
+
+     ```mlir
+     %0 = sparse_tensor.concatenate %1, %2 { dimension = 0 : index }
+       : tensor<64x64xf64, #CSR>, tensor<64x64xf64, #CSR> to tensor<128x64xf64, #CSR>
+     ```
+   }];
+
+  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Sparse Tensor Management Operations. These operations are "impure" in the
 // sense that they do not properly operate on SSA values. Instead, the behavior

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8092aff006d46..34d26abcc4666 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
@@ -352,6 +353,80 @@ LogicalResult UnaryOp::verify() {
   return success();
 }
 
+/// Verifies a `sparse_tensor.concatenate` op:
+///   * at least two inputs are given and all of them are statically sized;
+///   * `dimension` is within the rank of the result;
+///   * every input has the same rank as the result;
+///   * along `dimension`, a static result size equals the sum of the input
+///     sizes, and every other dimension agrees across inputs and result
+///     (a dynamic result size matches anything).
+/// NOTE: diagnostics (typos included) are matched by invalid.mlir tests.
+LogicalResult ConcatenateOp::verify() {
+  auto dstTp = getType().cast<RankedTensorType>();
+  uint64_t concatDim = getDimension().getZExtValue();
+  unsigned rank = dstTp.getRank();
+
+  if (getInputs().size() <= 1)
+    return emitError("Need at least two tensors to concatenate.");
+
+  for (auto type : getInputs().getTypes()) {
+    auto shape = type.cast<RankedTensorType>().getShape();
+    for (auto dim : shape) {
+      if (dim == ShapedType::kDynamicSize)
+        return emitError("Only statically-sized input tensors are supported.");
+    }
+  }
+
+  if (concatDim >= rank)
+    return emitError(llvm::formatv(
+        "Failed to concatentate tensors with rank={0} on dimension={1}.", rank,
+        concatDim));
+
+  for (size_t i = 0; i < getInputs().size(); i++) {
+    Value input = getInputs()[i];
+    auto inputRank = input.getType().cast<RankedTensorType>().getRank();
+    if (inputRank != rank)
+      return emitError(
+          llvm::formatv("The input tensor ${0} has a different rank (rank={1}) "
+                        "from the output tensor (rank={2}).",
+                        i, inputRank, rank));
+  }
+
+  for (unsigned i = 0; i < rank; i++) {
+    int64_t dstDim = dstTp.getShape()[i];
+    if (i == concatDim) {
+      if (dstDim != ShapedType::kDynamicSize) {
+        // Accumulate in int64_t rather than a 32-bit unsigned so that large
+        // static sizes cannot wrap around before the comparison below.
+        int64_t sumDim = 0;
+        for (auto src : getInputs()) {
+          // Inputs were checked to be statically sized above.
+          sumDim += src.getType().cast<RankedTensorType>().getShape()[i];
+        }
+        // Every size is static at this point, so the input sizes along the
+        // concatenation dimension must add up to the output size.
+        if (sumDim != dstDim)
+          return emitError(
+              "The concatenation dimension of the output tensor should be the "
+              "sum of all the concatenation dimensions of the input tensors.");
+      }
+    } else {
+      // `prev` starts as the output size (possibly dynamic, which matches
+      // anything) and then tracks the size every later input must match.
+      int64_t prev = dstDim;
+      for (auto src : getInputs()) {
+        auto d = src.getType().cast<RankedTensorType>().getShape()[i];
+        if (prev != ShapedType::kDynamicSize && d != prev)
+          return emitError("All dimensions (expect for the concatenating one) "
+                           "should be equal.");
+        prev = d;
+      }
+    }
+  }
+
+  return success();
+}
+
 LogicalResult ReduceOp::verify() {
   Type inputType = getX().getType();
   LogicalResult regionResult = success();

diff  --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 31f5654f63580..d9b48fe2240fa 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -360,3 +360,86 @@ func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
     }
   return %r : f64
 }
+
+// -----
+// Concatenation needs at least two input tensors.
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error @+1 {{Need at least two tensors to concatenate.}}
+  %0 = sparse_tensor.concatenate %arg {dimension = 1 : index}
+       : tensor<9x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+// The concatenation dimension must be smaller than the result rank.
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_dim(%arg0: tensor<2x4xf64, #DC>,
+                              %arg1: tensor<3x4xf64, #DC>,
+                              %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error @+1 {{Failed to concatentate tensors with rank=2 on dimension=4}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 4 : index}
+       : tensor<2x4xf64, #DC>,
+         tensor<3x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+// Every input must have the same rank as the result.
+#C = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#DCC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed", "compressed"]}>
+func.func @invalid_concat_rank_mismatch(%arg0: tensor<2xf64, #C>,
+                                        %arg1: tensor<3x4xf64, #DC>,
+                                        %arg2: tensor<4x4x4xf64, #DCC>) -> tensor<9x4xf64, #DC> {
+  // expected-error @+1 {{The input tensor $0 has a different rank (rank=1) from the output tensor (rank=2)}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2xf64, #C>,
+         tensor<3x4xf64, #DC>,
+         tensor<4x4x4xf64, #DCC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+// Dynamically-sized inputs are rejected.
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_size_mismatch_dyn(%arg0: tensor<?x4xf64, #DC>,
+                                            %arg1: tensor<5x4xf64, #DC>,
+                                            %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error @+1 {{Only statically-sized input tensors are supported.}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<?x4xf64, #DC>,
+         tensor<5x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+// A static result size along `dimension` must be the sum of the input sizes.
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_size_mismatch(%arg0: tensor<3x4xf64, #DC>,
+                                        %arg1: tensor<5x4xf64, #DC>,
+                                        %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error @+1 {{The concatenation dimension of the output tensor should be the sum of}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<3x4xf64, #DC>,
+         tensor<5x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+// All non-concatenated dimensions must agree across inputs and result.
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+func.func @invalid_concat_shape_mismatch(%arg0: tensor<2x4xf64, #DC>,
+                                         %arg1: tensor<3x3xf64, #DC>,
+                                         %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error @+1 {{All dimensions (expect for the concatenating one) should be equal}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2x4xf64, #DC>,
+         tensor<3x3xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 75fc964a32ce0..5edc977de7c00 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -289,4 +289,28 @@ func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 {
         sparse_tensor.yield %x : f64
     }
   return %r : f64
-}
\ No newline at end of file
+}
+
+// -----
+// Round-trip a three-way concatenation of DCSR matrices along dimension 0.
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @concat_sparse_sparse(
+//  CHECK-SAME:   %[[X:.*]]: tensor<2x4xf64
+//  CHECK-SAME:   %[[Y:.*]]: tensor<3x4xf64
+//  CHECK-SAME:   %[[Z:.*]]: tensor<4x4xf64
+//       CHECK:   %[[R:.*]] = sparse_tensor.concatenate %[[X]], %[[Y]], %[[Z]] {dimension = 0 : index} :
+//  CHECK-SAME:   tensor<2x4xf64
+//  CHECK-SAME:   tensor<3x4xf64
+//  CHECK-SAME:   tensor<4x4xf64
+//  CHECK-SAME:   tensor<9x4xf64
+//       CHECK:   return %[[R]] : tensor<9x4xf64
+func.func @concat_sparse_sparse(%x: tensor<2x4xf64, #DCSR>,
+                                %y: tensor<3x4xf64, #DCSR>,
+                                %z: tensor<4x4xf64, #DCSR>) -> tensor<9x4xf64, #DCSR> {
+  %r = sparse_tensor.concatenate %x, %y, %z {dimension = 0 : index}
+       : tensor<2x4xf64, #DCSR>,
+         tensor<3x4xf64, #DCSR>,
+         tensor<4x4xf64, #DCSR> to tensor<9x4xf64, #DCSR>
+  return %r : tensor<9x4xf64, #DCSR>
+}


        


More information about the Mlir-commits mailing list