[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)

Han-Chung Wang llvmlistbot at llvm.org
Fri Dec 1 11:00:04 PST 2023


================
@@ -471,6 +471,192 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the result type of concatenating tensors of types `inputTypes` along
+/// dimension `dim`.
+///
+/// Every non-concatenated dimension of the result is obtained by merging the
+/// corresponding sizes of all inputs; the concatenated dimension is the sum of
+/// the input sizes along `dim`. The element type is taken from the first input.
+///
+/// Preconditions (checked with assertions only): `inputTypes` is non-empty and
+/// contains only ranked tensor types of the same rank, `dim` is in
+/// [0, rank), and the inputs agree on the static sizes of every
+/// non-concatenated dimension.
+RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
+  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
+  auto tensorTypes =
+      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
+        return llvm::cast<RankedTensorType>(type);
+      }));
+  int64_t concatRank = tensorTypes[0].getRank();
+
+  // The concatenation dim must be in the range [0, rank).
+  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
+
+  // All inputs are indexed with dimensions in [0, concatRank) below, so they
+  // must share the rank of the first input.
+  assert(llvm::all_of(tensorTypes,
+                      [concatRank](RankedTensorType type) {
+                        return type.getRank() == concatRank;
+                      }) &&
+         "cannot concatenate tensors of different ranks");
+
+  SmallVector<int64_t> sizes(concatRank);
+  for (int64_t i = 0, e = concatRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : tensorTypes) {
+      // `desaturate` reports failure when the sizes cannot be merged; never
+      // dereference the result without checking it first.
+      FailureOr<SaturatedInteger> maybeSize =
+          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+      assert(succeeded(maybeSize) &&
+             "static size mismatch along non-concatenated dimension");
+      size = *maybeSize;
+    }
+    sizes[i] = size.asInteger();
+  }
+
+  // The size along the concatenated dimension is the sum over all inputs.
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : tensorTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
+}
+
+/// Builds a ConcatOp that concatenates `inputs` along `dim`, with the result
+/// type inferred from the input types via `inferResultType`.
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+                     ValueRange inputs) {
+  // `inferResultType` returns a RankedTensorType directly and asserts on
+  // invalid input itself, so there is no failure state to check here.
+  RankedTensorType resultType = inferResultType(dim, inputs.getTypes());
+  build(builder, result, resultType, dim, inputs);
+}
+
+/// Verifies the structural invariants of a ConcatOp:
+///  - there is at least one input;
+///  - every input has the same rank and element type as the result;
+///  - the concatenation dim is in the range [0, rank);
+///  - the inputs agree on the static sizes of all non-concatenated dims;
+///  - every static size of the declared result type matches the corresponding
+///    static size of the shape inferred from the inputs.
+/// Emits an op error and returns failure on the first violated invariant.
+LogicalResult ConcatOp::verify() {
+  if (getInputs().size() < 1)
+    return emitOpError("requires at least one input");
+
+  SmallVector<RankedTensorType> inputTypes;
+  for (auto input : getInputs())
+    inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+  RankedTensorType resultType = getResultType();
+  int64_t resultRank = getRank();
+  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
+        return type.getRank() != resultRank;
+      }))
+    return emitOpError("rank of concatenated inputs must match result rank");
+
+  Type resultElementType = resultType.getElementType();
+  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
+        return type.getElementType() != resultElementType;
+      }))
+    return emitOpError("inputs and result element type must match");
+
+  // `dim` is used below to index both `sizes` and the input shapes, so it must
+  // be rejected when it falls outside [0, rank) on either side.
+  int64_t dim = getDim();
+  if (dim < 0)
+    return emitOpError("concatenation dim must be non-negative");
+  if (dim >= resultRank)
+    return emitOpError("concatenation dim must be less than the tensor rank");
+
+  // Merge the sizes of every non-concatenated dimension across all inputs,
+  // diagnosing the first dimension on which the inputs disagree.
+  SmallVector<int64_t> sizes(resultRank);
+  for (int64_t i = 0, e = resultRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : inputTypes) {
+      FailureOr<SaturatedInteger> maybeSize =
+          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+      if (failed(maybeSize))
+        return emitOpError("static concatenation size mismatch along ")
+               << "non-concatenated dimension " << i;
+      size = *maybeSize;
+    }
+    sizes[i] = size.asInteger();
+  }
+
+  // The size along the concatenated dimension is the sum over all inputs.
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : inputTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  auto inferredResultType =
+      RankedTensorType::get(sizes, inputTypes[0].getElementType());
+
+  // The declared result type may differ from the inferred one only where at
+  // least one of the two sizes is dynamic.
+  for (auto [inferredSize, actualSize] :
+       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
+    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
+                      ShapedType::isDynamic(actualSize);
+    if (!hasDynamic && inferredSize != actualSize)
+      return emitOpError("result type ")
+             << resultType << " does not match inferred shape "
+             << inferredResultType << " static sizes";
+  }
+
+  return success();
+}
+
+LogicalResult
+ConcatOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  ValueRange inputs = getInputs();
+  int64_t dim = getDim();
+  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
+
+  Value init = inputs[0];
+  int64_t rank = getType().getRank();
+
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
+
+  // Pre-populate the result sizes with as much static information as possible
+  // from the given result type, as well as the inferred result type, otherwise
+  // use the dim sizes from the first input.
+  for (int64_t i = 0; i < rank; ++i) {
+    if (i == dim)
+      continue;
+    if (!getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+    } else if (!inferredResultType.isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] =
+          builder.getIndexAttr(inferredResultType.getDimSize(i));
+    } else {
+      reifiedReturnShapes[0][i] =
+          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
+    }
+  }
+
+  // Take the sum of the input sizes along the concatenated dim.
+  AffineExpr sum = builder.getAffineDimExpr(0);
+  SmallVector<OpFoldResult> sizes{
+      builder.create<tensor::DimOp>(init.getLoc(), init, 0).getResult()};
----------------
hanhanW wrote:

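Can we use assignment syntax (`= {...}`) here? Brace initialization with constructor syntax is error-prone: whether the braces are treated as an element list or as constructor arguments depends on whether an `initializer_list` constructor is viable for the element type, which can silently lead to bugs, e.g.,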

```
std::vector<std::string> strings{2}; // A vector of two empty strings.
std::vector<int> ints{2};            // A vector containing only the integer 2.
```

For more context, see https://abseil.io/tips/88

https://github.com/llvm/llvm-project/pull/72779


More information about the Mlir-commits mailing list