[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)
Quinn Dawkins
llvmlistbot at llvm.org
Wed Nov 29 21:00:21 PST 2023
================
@@ -471,6 +471,202 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<RankedTensorType> ConcatOp::inferResultType(int64_t dim,
+                                                      TypeRange inputTypes) {
+  if (dim < 0)
+    return failure();
+
+  if (inputTypes.empty())
+    return failure();
+
+  RankedTensorType init = dyn_cast<RankedTensorType>(inputTypes[0]);
+  if (!init)
+    return failure();
+
+  // The tensor rank must be greater than the concatenation dim.
+  int64_t concatRank = init.getRank();
+  if (concatRank <= dim)
+    return failure();
+
+  SmallVector<int64_t> sizes(init.getShape());
+  Type elementType = init.getElementType();
+  for (Type type : inputTypes.drop_front()) {
----------------
qedawkins wrote:
I moved it and reused it. Let me know if that's what you were looking for (and if the naming makes sense).
https://github.com/llvm/llvm-project/pull/72779
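
For readers without the full patch at hand: the quoted hunk stops at the loop over the remaining inputs, so the loop body is not shown here. The following is a minimal standalone sketch of how result-shape inference for a concatenation could work. It is an assumption and not the code from the PR. It uses no MLIR types; `Shape`, `kDynamic`, and `inferConcatShape` are hypothetical names introduced only for this illustration. It assumes that sizes along the concatenation dim are summed (becoming dynamic if any input is dynamic there) and that all other dims must agree wherever both sizes are static.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for a dynamic dimension size (not the MLIR constant).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical shape type: one size per dimension.
using Shape = std::vector<int64_t>;

// Returns the inferred result shape, or std::nullopt where the quoted code
// would return failure().
std::optional<Shape> inferConcatShape(int64_t dim,
                                      const std::vector<Shape> &inputs) {
  // Same early checks as the quoted hunk: non-negative dim, non-empty inputs.
  if (dim < 0 || inputs.empty())
    return std::nullopt;

  Shape sizes = inputs.front();
  const int64_t rank = static_cast<int64_t>(sizes.size());
  // The rank must be greater than the concatenation dim.
  if (rank <= dim)
    return std::nullopt;

  for (std::size_t i = 1; i < inputs.size(); ++i) {
    const Shape &shape = inputs[i];
    // Assumption: all inputs must have the same rank.
    if (static_cast<int64_t>(shape.size()) != rank)
      return std::nullopt;

    for (int64_t d = 0; d < rank; ++d) {
      if (d == dim) {
        // Along the concat dim, sizes add up; any dynamic size makes the
        // result dynamic.
        if (sizes[d] == kDynamic || shape[d] == kDynamic)
          sizes[d] = kDynamic;
        else
          sizes[d] += shape[d];
        continue;
      }
      // Elsewhere, a static size refines a dynamic one, and two static sizes
      // must match.
      if (sizes[d] == kDynamic)
        sizes[d] = shape[d];
      else if (shape[d] != kDynamic && shape[d] != sizes[d])
        return std::nullopt;
    }
  }
  return sizes;
}
```

Under these assumptions, concatenating shapes {2, 3} and {4, 3} along dim 0 gives {6, 3}, while {2, 3} and {4, 5} along dim 0 fails because dim 1 disagrees.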