[Mlir-commits] [mlir] de90713 - [mlir][sparse] Add new concatenate operator to sparse tensor
Peiming Liu
llvmlistbot at llvm.org
Mon Aug 8 10:23:49 PDT 2022
Author: Peiming Liu
Date: 2022-08-08T17:23:43Z
New Revision: de907138ec96de063660b91a8adc7f28aa1bea98
URL: https://github.com/llvm/llvm-project/commit/de907138ec96de063660b91a8adc7f28aa1bea98
DIFF: https://github.com/llvm/llvm-project/commit/de907138ec96de063660b91a8adc7f28aa1bea98.diff
LOG: [mlir][sparse] Add new concatenate operator to sparse tensor
See https://www.tensorflow.org/xla/operation_semantics#concatenate for the operator semantics
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D131111
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 9f99b52291177..db0fa64756897 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -163,6 +163,34 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
let hasVerifier = 1;
}
+def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [NoSideEffect]>,
+    Arguments<(ins Variadic<AnyRankedTensor>:$inputs,
+                   IndexAttr:$dimension)>,
+    Results<(outs AnyRankedTensor:$result)> {
+
+  let summary = "Concatenates a list of tensors into a single tensor.";
+  let description = [{
+     Concatenates a list of input tensors into an output tensor of the same
+     rank. The concatenation happens along the specified `dimension`
+     (0 <= dimension < rank). The output size along `dimension` is the sum of
+     the input sizes along that dimension, while every other dimension must
+     have the same size in all input tensors and in the output tensor (a
+     dynamically-sized output dimension is accepted).
+
+     Only statically-sized input tensors are accepted, while the output tensor
+     can be dynamically-sized.
+
+     Example:
+
+     ```mlir
+     %0 = sparse_tensor.concatenate %1, %2 { dimension = 0 : index }
+       : tensor<64x64xf64, #CSR>, tensor<64x64xf64, #CSR> to tensor<128x64xf64, #CSR>
+     ```
+  }];
+
+  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
+  let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Management Operations. These operations are "impure" in the
// sense that they do not properly operate on SSA values. Instead, the behavior
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 8092aff006d46..34d26abcc4666 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
@@ -352,6 +353,69 @@ LogicalResult UnaryOp::verify() {
return success();
}
+/// Verifies a `sparse_tensor.concatenate` op:
+///  - at least two input tensors are given;
+///  - every input tensor is fully statically shaped;
+///  - the concatenation dimension is smaller than the output rank;
+///  - every input has the same rank as the output;
+///  - if the output size along the concatenation dimension is static, it
+///    equals the sum of the input sizes along that dimension;
+///  - every other dimension has the same size across all inputs, and matches
+///    the output when the output size is static.
+LogicalResult ConcatenateOp::verify() {
+  auto dstTp = getType().cast<RankedTensorType>();
+  uint64_t concatDim = getDimension().getZExtValue();
+  unsigned rank = dstTp.getRank();
+
+  if (getInputs().size() <= 1)
+    return emitError("Need at least two tensors to concatenate.");
+
+  for (auto type : getInputs().getTypes()) {
+    auto shape = type.cast<RankedTensorType>().getShape();
+    for (auto dim : shape) {
+      if (dim == ShapedType::kDynamicSize)
+        return emitError("Only statically-sized input tensors are supported.");
+    }
+  }
+
+  // NOTE(review): "concatentate" is misspelled, but invalid.mlir matches this
+  // exact text; fix both together.
+  if (concatDim >= rank)
+    return emitError(llvm::formatv(
+        "Failed to concatentate tensors with rank={0} on dimension={1}.", rank,
+        concatDim));
+
+  for (size_t i = 0; i < getInputs().size(); i++) {
+    Value input = getInputs()[i];
+    auto inputRank = input.getType().cast<RankedTensorType>().getRank();
+    if (inputRank != rank)
+      return emitError(
+          llvm::formatv("The input tensor ${0} has a different rank (rank={1}) "
+                        "from the output tensor (rank={2}).",
+                        i, inputRank, rank));
+  }
+
+  // The rank check above guarantees getShape()[i] is in bounds for every
+  // input below.
+  for (unsigned i = 0; i < rank; i++) {
+    // Keep sizes as int64_t (ShapedType's dimension type): a narrower type
+    // could truncate large sizes or wrap the running sum.
+    int64_t dstDim = dstTp.getShape()[i];
+    if (i == concatDim) {
+      if (dstDim != ShapedType::kDynamicSize) {
+        int64_t sumDim = 0;
+        for (auto src : getInputs()) {
+          // If we reach here, all inputs should have static shapes.
+          sumDim += src.getType().cast<RankedTensorType>().getShape()[i];
+        }
+        // If all dimension are statically known, the sum of all the input
+        // dimensions should be equal to the output dimension.
+        if (sumDim != dstDim)
+          return emitError(
+              "The concatenation dimension of the output tensor should be the "
+              "sum of all the concatenation dimensions of the input tensors.");
+      }
+    } else {
+      // Start from the output size; a dynamic output size skips the check
+      // against the first input, after which all inputs must agree.
+      int64_t prev = dstDim;
+      for (auto src : getInputs()) {
+        int64_t d = src.getType().cast<RankedTensorType>().getShape()[i];
+        // NOTE(review): "expect for" should read "except for", but
+        // invalid.mlir matches this exact text.
+        if (prev != ShapedType::kDynamicSize && d != prev)
+          return emitError("All dimensions (expect for the concatenating one) "
+                           "should be equal.");
+        prev = d;
+      }
+    }
+  }
+
+  return success();
+}
+
LogicalResult ReduceOp::verify() {
Type inputType = getX().getType();
LogicalResult regionResult = success();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index 31f5654f63580..d9b48fe2240fa 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -360,3 +360,86 @@ func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
}
return %r : f64
}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+// A single input is rejected: concatenation needs at least two tensors.
+func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{Need at least two tensors to concatenate.}}
+  %0 = sparse_tensor.concatenate %arg {dimension = 1 : index}
+       : tensor<9x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+// The concatenation dimension must be smaller than the output rank.
+func.func @invalid_concat_dim(%arg0: tensor<2x4xf64, #DC>,
+                              %arg1: tensor<3x4xf64, #DC>,
+                              %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{Failed to concatentate tensors with rank=2 on dimension=4}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 4 : index}
+       : tensor<2x4xf64, #DC>,
+         tensor<3x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#C = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#DCC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed", "compressed"]}>
+// Every input must have the same rank as the output (%arg0 has rank 1).
+func.func @invalid_concat_rank_mismatch(%arg0: tensor<2xf64, #C>,
+                                        %arg1: tensor<3x4xf64, #DC>,
+                                        %arg2: tensor<4x4x4xf64, #DCC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{The input tensor $0 has a different rank (rank=1) from the output tensor (rank=2)}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2xf64, #C>,
+         tensor<3x4xf64, #DC>,
+         tensor<4x4x4xf64, #DCC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+// Dynamically-sized inputs are rejected.
+func.func @invalid_concat_size_mismatch_dyn(%arg0: tensor<?x4xf64, #DC>,
+                                            %arg1: tensor<5x4xf64, #DC>,
+                                            %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{Only statically-sized input tensors are supported.}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<?x4xf64, #DC>,
+         tensor<5x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+// The static output size along dimension 0 (9) must equal 3 + 5 + 4.
+func.func @invalid_concat_size_mismatch(%arg0: tensor<3x4xf64, #DC>,
+                                        %arg1: tensor<5x4xf64, #DC>,
+                                        %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{The concatenation dimension of the output tensor should be the sum of}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<3x4xf64, #DC>,
+         tensor<5x4xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
+// -----
+
+#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+// Non-concatenation dimensions must agree (%arg1 has size 3 in dimension 1).
+func.func @invalid_concat_non_concat_dim_mismatch(%arg0: tensor<2x4xf64, #DC>,
+                                                  %arg1: tensor<3x3xf64, #DC>,
+                                                  %arg2: tensor<4x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
+  // expected-error@+1 {{All dimensions (expect for the concatenating one) should be equal}}
+  %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index}
+       : tensor<2x4xf64, #DC>,
+         tensor<3x3xf64, #DC>,
+         tensor<4x4xf64, #DC> to tensor<9x4xf64, #DC>
+  return %0 : tensor<9x4xf64, #DC>
+}
+
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 75fc964a32ce0..5edc977de7c00 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -289,4 +289,28 @@ func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 {
sparse_tensor.yield %x : f64
}
return %r : f64
-}
\ No newline at end of file
+}
+
+// -----
+
+#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// Round-trips a three-way concatenation of 2-D sparse tensors along dimension 0.
+// CHECK-LABEL: func @concat_sparse_sparse(
+//  CHECK-SAME:   %[[IN0:.*]]: tensor<2x4xf64
+//  CHECK-SAME:   %[[IN1:.*]]: tensor<3x4xf64
+//  CHECK-SAME:   %[[IN2:.*]]: tensor<4x4xf64
+//       CHECK:   %[[OUT:.*]] = sparse_tensor.concatenate %[[IN0]], %[[IN1]], %[[IN2]] {dimension = 0 : index} :
+//  CHECK-SAME:   tensor<2x4xf64
+//  CHECK-SAME:   tensor<3x4xf64
+//  CHECK-SAME:   tensor<4x4xf64
+//  CHECK-SAME:   tensor<9x4xf64
+//       CHECK:   return %[[OUT]] : tensor<9x4xf64
+func.func @concat_sparse_sparse(%a: tensor<2x4xf64, #DCSR>,
+                                %b: tensor<3x4xf64, #DCSR>,
+                                %c: tensor<4x4xf64, #DCSR>) -> tensor<9x4xf64, #DCSR> {
+  %cat = sparse_tensor.concatenate %a, %b, %c {dimension = 0 : index}
+       : tensor<2x4xf64, #DCSR>, tensor<3x4xf64, #DCSR>, tensor<4x4xf64, #DCSR>
+         to tensor<9x4xf64, #DCSR>
+  return %cat : tensor<9x4xf64, #DCSR>
+}
More information about the Mlir-commits
mailing list