[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)

Han-Chung Wang llvmlistbot at llvm.org
Mon Nov 27 15:18:36 PST 2023


================
@@ -471,6 +471,202 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<RankedTensorType> ConcatOp::inferResultType(int64_t dim,
+                                                      TypeRange inputTypes) {
+  if (dim < 0)
+    return failure();
+
+  if (inputTypes.empty())
+    return failure();
+
+  RankedTensorType init = dyn_cast<RankedTensorType>(inputTypes[0]);
+  if (!init)
+    return failure();
+
+  // The tensor rank must be greater than the concatenation dim.
+  int64_t concatRank = init.getRank();
+  if (concatRank <= dim)
+    return failure();
+
+  SmallVector<int64_t> sizes(init.getShape());
+  Type elementType = init.getElementType();
+  for (Type type : inputTypes.drop_front()) {
+    RankedTensorType tensorType = dyn_cast<RankedTensorType>(type);
+    if (!tensorType || tensorType.getRank() != concatRank ||
+        tensorType.getElementType() != elementType)
+      return failure();
+
+    for (auto [index, currSize] : llvm::enumerate(tensorType.getShape())) {
+      int64_t size = sizes[index];
+      bool hasDynamic =
+          ShapedType::isDynamic(size) || ShapedType::isDynamic(currSize);
+      if (static_cast<int64_t>(index) == dim) {
+        sizes[index] = hasDynamic ? ShapedType::kDynamic : currSize + size;
+        continue;
+      }
+
+      // If the sizes are statically different for a dimension other than the
+      // concatenated dimension, the concatenation is invalid. Both sizes
+      // being dynamic, or a mix of dynamic and static, is fine.
+      if (currSize != size && !hasDynamic)
+        return failure();
+
+      // If the new size is not dynamic, use the additional static information.
+      if (!ShapedType::isDynamic(currSize))
+        sizes[index] = currSize;
+    }
+  }
+
+  return RankedTensorType::get(sizes, elementType);
+}
+
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+                     ValueRange inputs) {
+  FailureOr<RankedTensorType> resultType =
+      inferResultType(dim, inputs.getTypes());
+  assert(succeeded(resultType) && "failed to infer concatenation result type");
+  build(builder, result, *resultType, dim, inputs);
+}
+
+LogicalResult ConcatOp::verify() {
+  if (getInputs().empty())
+    return emitOpError("requires at least one input");
+
+  SmallVector<RankedTensorType> inputTypes;
+  for (auto input : getInputs())
+    inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+  RankedTensorType resultType = getResultType();
+
+  int64_t resultRank = resultType.getRank();
----------------
hanhanW wrote:

[optional] I think adding a `getRank()` method to the concat op would be useful.
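
A minimal sketch of what I have in mind (the name and placement are only a suggestion, not something this PR defines; it would also need a matching declaration in TensorOps.td):

```cpp
// Convenience accessor: forwards to the result type so callers can write
// concat.getRank() instead of concat.getResultType().getRank().
int64_t ConcatOp::getRank() { return getResultType().getRank(); }
```

The verifier above could then read `int64_t resultRank = getRank();` directly.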

https://github.com/llvm/llvm-project/pull/72779

