[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 20 04:50:29 PST 2023
================
@@ -471,6 +471,202 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the ranked tensor type produced by concatenating `inputTypes` along
+/// dimension `dim`.
+///
+/// Fails when `dim` is negative, when there are no inputs, when any input is
+/// not a ranked tensor, when `dim` is not smaller than the rank of the first
+/// input, when ranks or element types disagree between inputs, or when two
+/// inputs have statically different sizes in a dimension other than `dim`.
+///
+/// In the result shape, the size along `dim` is the sum of the input sizes
+/// (dynamic as soon as any contributing size is dynamic); every other
+/// dimension takes the most recently seen static size, if there is one.
+FailureOr<RankedTensorType> ConcatOp::inferResultType(int64_t dim,
+                                                      TypeRange inputTypes) {
+  if (dim < 0 || inputTypes.empty())
+    return failure();
+
+  auto firstType = dyn_cast<RankedTensorType>(inputTypes[0]);
+  if (!firstType)
+    return failure();
+
+  // The concatenation dim must index into the tensor rank.
+  int64_t rank = firstType.getRank();
+  if (dim >= rank)
+    return failure();
+
+  Type elementType = firstType.getElementType();
+  SmallVector<int64_t> resultShape(firstType.getShape());
+  for (Type inputType : inputTypes.drop_front()) {
+    auto rankedType = dyn_cast<RankedTensorType>(inputType);
+    if (!rankedType)
+      return failure();
+    if (rankedType.getRank() != rank ||
+        rankedType.getElementType() != elementType)
+      return failure();
+
+    auto inputShape = rankedType.getShape();
+    for (int64_t i = 0; i < rank; ++i) {
+      int64_t accumulated = resultShape[i];
+      int64_t incoming = inputShape[i];
+      bool eitherDynamic = ShapedType::isDynamic(accumulated) ||
+                           ShapedType::isDynamic(incoming);
+
+      // Along the concatenation dim the sizes add up; a single dynamic
+      // contribution makes the whole sum dynamic.
+      if (i == dim) {
+        resultShape[i] =
+            eitherDynamic ? ShapedType::kDynamic : incoming + accumulated;
+        continue;
+      }
+
+      // Away from the concatenation dim, two static sizes have to agree.
+      // Dynamic/dynamic and dynamic/static pairs are both accepted.
+      if (!eitherDynamic && incoming != accumulated)
+        return failure();
+
+      // Keep whatever static information this input contributes.
+      if (!ShapedType::isDynamic(incoming))
+        resultShape[i] = incoming;
+    }
+  }
+
+  return RankedTensorType::get(resultShape, elementType);
+}
+
+/// Builder that derives the result type from `inputs` and `dim` via
+/// `inferResultType` and then forwards to the builder taking an explicit
+/// result type. A failed inference is only caught by the assertion.
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+                     ValueRange inputs) {
+  FailureOr<RankedTensorType> inferredType =
+      inferResultType(dim, inputs.getTypes());
+  assert(succeeded(inferredType) &&
+         "failed to infer concatenation result type");
+  build(builder, result, *inferredType, dim, inputs);
+}
+
+/// Verifies the invariants of the concat op:
+///   - there is at least one input;
+///   - every input has the same rank and element type as the result;
+///   - the concatenation dim is less than the result rank;
+///   - a result type can be inferred from the inputs, and each of its static
+///     sizes agrees with the corresponding static size of the declared result
+///     type (a dynamic size on either side is accepted).
+LogicalResult ConcatOp::verify() {
+  if (getInputs().empty())
+    return emitOpError("requires at least one input");
+
+  // NOTE(review): `cast` is used here, so inputs are presumably already
+  // constrained to ranked tensors by the op definition -- confirm.
+  SmallVector<RankedTensorType> inputTypes;
+  for (auto input : getInputs())
+    inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+  RankedTensorType resultType = getResultType();
+
+  int64_t resultRank = resultType.getRank();
+  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
+        return type.getRank() != resultRank;
+      }))
+    return emitOpError("rank of concatenated inputs must match result rank");
+
+  Type resultElementType = resultType.getElementType();
+  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
+        return type.getElementType() != resultElementType;
+      }))
+    return emitOpError("inputs and result element type must match");
+
+  if (static_cast<int64_t>(getDim()) >= resultRank)
+    return emitOpError("concatenation dim must be less than the tensor rank");
+
+  FailureOr<RankedTensorType> inferredResultType =
+      inferResultType(getDim(), getInputs().getTypes());
+  if (failed(inferredResultType))
+    return emitOpError("failed to infer concatenation result type from inputs");
+
+  // Only statically known sizes can conflict; a dynamic size on either side is
+  // compatible with anything.
+  for (auto [inferredSize, actualSize] :
+       llvm::zip_equal(inferredResultType->getShape(), resultType.getShape())) {
+    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
+                      ShapedType::isDynamic(actualSize);
+    if (!hasDynamic && inferredSize != actualSize)
+      // The leading space in " does not match" keeps the printed result type
+      // from running into the following word.
+      return emitOpError("result type ")
+             << resultType << " does not match inferred shape "
+             << *inferredResultType << " static sizes";
+  }
+
+  return success();
+}
+
+LogicalResult
+ConcatOp::reifyResultShapes(OpBuilder &builder,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ ValueRange inputs = getInputs();
+ int64_t dim = getDim();
+ FailureOr<RankedTensorType> maybeInferredResultType =
+ inferResultType(dim, inputs.getTypes());
+ if (failed(maybeInferredResultType))
+ return failure();
+ RankedTensorType inferredResultType = *maybeInferredResultType;
+
+ Value init = inputs[0];
+ int64_t rank = getType().getRank();
+
+ reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
+
+ // Pre-populate the result sizes with as much static information as possible
+ // from the given result type, as well as the inferred result type, otherwise
+ // use the dim sizes from the first input.
+ bool hasStaticConcatDim = false;
+ for (int64_t i = 0; i < rank; ++i) {
----------------
nicolasvasilache wrote:
I would probably also structure this differently to make it simpler to read:
1. for all dimensions other than `dim`, copy the shape of the first tensor (assuming the interplay with the verifier is good)
2. for `dim`, take the static result size of the sum.
Looking deeper, it seems similar to what you are implementing here; I'd just expect this to take ~10 lines and be very easy for human eyes to parse (vs. the ~40 I see here).
https://github.com/llvm/llvm-project/pull/72779
More information about the Mlir-commits
mailing list