[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)
Han-Chung Wang
llvmlistbot at llvm.org
Fri Dec 1 11:00:05 PST 2023
================
@@ -471,6 +471,192 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
+ assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
+ auto tensorTypes =
+ llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
+ return llvm::cast<RankedTensorType>(type);
+ }));
+ int64_t concatRank = tensorTypes[0].getRank();
+
+ // The concatenation dim must be in the range [0, rank).
+ assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
+
+ SmallVector<int64_t> sizes((concatRank));
+ for (int64_t i = 0, e = concatRank; i < e; ++i) {
+ if (i == dim)
+ continue;
+ SaturatedInteger size;
+ for (auto tensorType : tensorTypes)
+ size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+ sizes[i] = size.asInteger();
+ }
+ auto concatSize = SaturatedInteger::wrap(0);
+ for (auto tensorType : tensorTypes)
+ concatSize =
+ concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+ sizes[dim] = concatSize.asInteger();
+ return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
+}
+
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+ ValueRange inputs) {
+ FailureOr<RankedTensorType> resultType =
+ inferResultType(dim, inputs.getTypes());
+ assert(succeeded(resultType) && "failed to infer concatenation result type");
+ build(builder, result, *resultType, dim, inputs);
+}
+
+LogicalResult ConcatOp::verify() {
+ if (getInputs().size() < 1)
+ return emitOpError("requires at least one input");
+
+ SmallVector<RankedTensorType> inputTypes;
+ for (auto input : getInputs())
+ inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+ RankedTensorType resultType = getResultType();
+ int64_t resultRank = getRank();
+ if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
+ return type.getRank() != resultRank;
+ }))
+ return emitOpError("rank of concatenated inputs must match result rank");
+
+ Type resultElementType = resultType.getElementType();
+ if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
+ return type.getElementType() != resultElementType;
+ }))
+ return emitOpError("inputs and result element type must match");
+
+ int64_t dim = getDim();
+ if (dim >= resultRank)
+ return emitOpError("concatenation dim must be less than the tensor rank");
+
+ SmallVector<int64_t> sizes((resultRank));
----------------
hanhanW wrote:
ditto
```suggestion
SmallVector<int64_t> sizes(resultRank);
```
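The "ditto" points back to the same pattern earlier in the hunk, `SmallVector<int64_t> sizes((concatRank));`: the inner parentheses are redundant, since `(resultRank)` is just a parenthesized expression and both spellings invoke the same SmallVector size constructor. A minimal standalone sketch of the equivalence (not part of the patch, names are illustrative):

```cpp
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

void example(int64_t resultRank) {
  // Both declarations call SmallVector's size constructor and produce
  // `resultRank` value-initialized (zero) elements; the extra parentheses
  // change nothing, so the suggested spelling is preferred for readability.
  llvm::SmallVector<int64_t> a(resultRank);   // suggested spelling
  llvm::SmallVector<int64_t> b((resultRank)); // spelling in the patch
}
```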
https://github.com/llvm/llvm-project/pull/72779
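For reference, the shape rule that `ConcatOp::inferResultType` implements in the hunk above: non-concatenated dimensions must agree, with a static size resolving (desaturating) a dynamic one, while the concatenated dimension is the sum of the input sizes and stays dynamic if any input is dynamic. A standalone sketch under the assumption that `-1` stands in for `ShapedType::kDynamic`, without the MLIR/SaturatedInteger APIs:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the size logic of ConcatOp::inferResultType.
constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

std::vector<int64_t>
inferConcatShape(int64_t dim, const std::vector<std::vector<int64_t>> &shapes) {
  assert(!shapes.empty() && "cannot concatenate 0 tensors");
  int64_t rank = static_cast<int64_t>(shapes[0].size());
  assert(dim >= 0 && dim < rank && "invalid concatenation dim");

  std::vector<int64_t> sizes(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (i == dim)
      continue;
    // Non-concatenated dims: any static size overrides dynamic; static
    // sizes are assumed to agree (the op verifier enforces this).
    int64_t size = kDynamic;
    for (const auto &shape : shapes)
      if (shape[i] != kDynamic)
        size = shape[i];
    sizes[i] = size;
  }
  // Concatenated dim: sum of input sizes, dynamic if any input is dynamic.
  int64_t concatSize = 0;
  for (const auto &shape : shapes) {
    if (shape[dim] == kDynamic) {
      concatSize = kDynamic;
      break;
    }
    concatSize += shape[dim];
  }
  sizes[dim] = concatSize;
  return sizes;
}
```

For example, concatenating shapes `3x?` and `5x4` along dim 0 under this rule yields `8x4`: the concat dim sums 3 + 5, and the second dim's static 4 resolves the dynamic size.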