[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)

Quinn Dawkins llvmlistbot at llvm.org
Sat Nov 18 20:26:18 PST 2023


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/72779

This adds an operation for concatenating ranked tensors along a static dimension, as well as a decomposition mirroring the existing lowering from TOSA to Tensor. This offers a convergence point for "input"-like dialects that include various lowerings for concatenation operations, easing later analysis. In the future, this op can implement the necessary interfaces for tiling, as well as potentially add conversions to some kind of linalg and/or memref counterpart.

This patch adds the op, the decomposition, and some basic folding/canonicalization. Replacing lowerings with the op (such as the TOSA lowering) will come as a follow-up.

See https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858

>From 28b4d32886c463462339ba9167e2206a50abb9af Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sat, 18 Nov 2023 17:02:40 -0500
Subject: [PATCH] [mlir][tensor] Add tensor.concat operation

This adds an operation for concatenating ranked tensors along a static
dimension, as well as a decomposition mirroring the existing lowering
from TOSA to Tensor. This offers a convergence point for "input"-like
dialects that include various lowerings for concatenation operations,
easing later analysis.
In the future, this op can implement the necessary interfaces for
tiling, as well as potentially add conversions to some kind of linalg
and/or memref counterpart.

See https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  60 ++++++
 .../Dialect/Tensor/Transforms/Transforms.h    |   5 +
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 191 ++++++++++++++++++
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   2 +
 .../Tensor/Transforms/ConcatOpPatterns.cpp    |  86 ++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  12 ++
 .../test/Dialect/Tensor/decompose-concat.mlir |  39 ++++
 mlir/test/Dialect/Tensor/invalid.mlir         |  48 +++++
 mlir/test/Dialect/Tensor/ops.mlir             |  17 ++
 .../Dialect/Tensor/TestTensorTransforms.cpp   |  12 ++
 10 files changed, 472 insertions(+)
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
 create mode 100644 mlir/test/Dialect/Tensor/decompose-concat.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7ae27407a9526e7..9101b7fa7c04005 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -121,6 +121,66 @@ def Tensor_CastOp : Tensor_Op<"cast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ConcatOp : Tensor_Op<"concat",
+    [Pure,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "tensor concatenation operation";
+  let description = [{
+    The "concat" operation constructs a tensor out of a variadic list of input
+    tensors, concatenated along a static dimension. All inputs and the result
+    type must share the same rank and element type.
+
+    `dim` specifies the dimension along which to concatenate. The size of the
+    concatenated dimension in the result must be equal to the sum of the sizes
+    of the inputs along that dimension. All other dimensions must have the same
+    size across the inputs and result; a dynamic size matches any size.
+
+    Example:
+
+    ```mlir
+    %0 = tensor.concat dim(0) %a, %b, %c :
+        (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32>) -> tensor<7x6xf32>
+
+    // Dynamic + static + dynamic -> static
+    %1 = tensor.concat dim(1) %d, %e, %f :
+        (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>
+    ```
+  }];
+  let arguments = (ins I64Attr:$dim,
+                       Variadic<AnyRankedTensor>:$inputs);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    `dim` `(` $dim `)` $inputs attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let builders = [
+    // Builder with an inferred result type.
+    OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Infers the result type of concatenating `inputTypes` along `dim`. The
+    // declared result type may carry more static information than can be
+    // inferred, so InferTypeOpInterface is deliberately not used.
+    static FailureOr<RankedTensorType> inferResultType(int64_t dim,
+                                                       TypeRange inputTypes);
+
+    RankedTensorType getResultType() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType());
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 705b30e7ded4779..c1627d20c2b3b6c 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,11 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                                      bool foldSingleUseOnly = false);
 
+/// Populates `patterns` with patterns that decompose `tensor.concat` into a
+/// `tensor.empty` of the concatenated size, a chain of `tensor.insert_slice`
+/// ops (one per input), and a final `tensor.cast` to the op's result type.
+void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e469815496e1832..f74a2ef689e7d67 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -471,6 +471,197 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<RankedTensorType> ConcatOp::inferResultType(int64_t dim,
+                                                      TypeRange inputTypes) {
+  // A negative dim or an empty input list never yields a valid type.
+  if (dim < 0 || inputTypes.empty())
+    return failure();
+
+  auto firstType = dyn_cast<RankedTensorType>(inputTypes.front());
+  if (!firstType)
+    return failure();
+
+  // The concatenation dim has to index into the shared rank.
+  int64_t rank = firstType.getRank();
+  if (dim >= rank)
+    return failure();
+
+  Type elementType = firstType.getElementType();
+  SmallVector<int64_t> resultShape(firstType.getShape());
+
+  for (Type nextType : inputTypes.drop_front()) {
+    auto tensorType = dyn_cast<RankedTensorType>(nextType);
+    bool compatible = tensorType && tensorType.getRank() == rank &&
+                      tensorType.getElementType() == elementType;
+    if (!compatible)
+      return failure();
+
+    ArrayRef<int64_t> nextShape = tensorType.getShape();
+    for (int64_t i = 0; i < rank; ++i) {
+      int64_t accumulated = resultShape[i];
+      int64_t incoming = nextShape[i];
+      bool anyDynamic = ShapedType::isDynamic(accumulated) ||
+                        ShapedType::isDynamic(incoming);
+
+      if (i == dim) {
+        // Sizes add up along `dim`; the sum is unknown if any term is.
+        resultShape[i] =
+            anyDynamic ? ShapedType::kDynamic : accumulated + incoming;
+      } else if (!anyDynamic && accumulated != incoming) {
+        // Conflicting static sizes outside `dim` cannot be concatenated.
+        return failure();
+      } else if (!ShapedType::isDynamic(incoming)) {
+        // Keep any extra static information the new input provides.
+        resultShape[i] = incoming;
+      }
+    }
+  }
+
+  return RankedTensorType::get(resultShape, elementType);
+}
+
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+                     ValueRange inputs) {
+  // Derive the result type from the operands, then use the generated builder.
+  auto inferred = inferResultType(dim, inputs.getTypes());
+  assert(succeeded(inferred) && "failed to infer concatenation result type");
+  build(builder, result, *inferred, dim, inputs);
+}
+
+LogicalResult ConcatOp::verify() {
+  ValueRange inputs = getInputs();
+  if (inputs.empty())
+    return emitOpError("requires at least one input");
+
+  RankedTensorType resultType = getResultType();
+  int64_t resultRank = resultType.getRank();
+  auto inputTypes = llvm::map_range(inputs.getTypes(), [](Type type) {
+    return cast<RankedTensorType>(type);
+  });
+  // Rank mismatches are diagnosed before element type mismatches.
+  if (!llvm::all_of(inputTypes, [&](RankedTensorType type) {
+        return type.getRank() == resultRank;
+      }))
+    return emitOpError("rank of concatenated inputs must match result rank");
+  if (!llvm::all_of(inputTypes, [&](RankedTensorType type) {
+        return type.getElementType() == resultType.getElementType();
+      }))
+    return emitOpError("inputs and result element type must match");
+
+  int64_t dim = static_cast<int64_t>(getDim());
+  if (dim >= resultRank)
+    return emitOpError("concatenation dim must be less than the tensor rank");
+
+  FailureOr<RankedTensorType> inferredResultType =
+      inferResultType(dim, inputs.getTypes());
+  if (failed(inferredResultType))
+    return emitOpError("failed to infer concatenation result type from inputs");
+
+  // Only sizes that are static on both sides have to agree.
+  for (auto [inferredSize, actualSize] :
+       llvm::zip_equal(inferredResultType->getShape(), resultType.getShape())) {
+    bool bothStatic = !ShapedType::isDynamic(inferredSize) &&
+                      !ShapedType::isDynamic(actualSize);
+    if (bothStatic && inferredSize != actualSize)
+      return emitOpError("result type ")
+             << resultType << "does not match inferred shape "
+             << *inferredResultType << " static sizes";
+  }
+
+  return success();
+}
+
+LogicalResult
+ConcatOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  ValueRange inputs = getInputs();
+  int64_t dim = getDim();
+  FailureOr<RankedTensorType> inferred =
+      inferResultType(dim, inputs.getTypes());
+  if (failed(inferred))
+    return failure();
+
+  RankedTensorType resultType = getResultType();
+  Value firstInput = inputs.front();
+  int64_t rank = resultType.getRank();
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
+  SmallVector<OpFoldResult> &shape = reifiedReturnShapes.front();
+
+  // Seed each dim with the most precise size available: the declared result
+  // size, else the inferred size, else the size of the first input.
+  bool concatDimIsStatic = false;
+  for (int64_t i = 0; i < rank; ++i) {
+    int64_t staticSize = ShapedType::kDynamic;
+    if (!resultType.isDynamicDim(i))
+      staticSize = resultType.getDimSize(i);
+    else if (!inferred->isDynamicDim(i))
+      staticSize = inferred->getDimSize(i);
+
+    if (ShapedType::isDynamic(staticSize)) {
+      shape[i] =
+          tensor::getMixedSize(builder, firstInput.getLoc(), firstInput, i);
+      continue;
+    }
+    shape[i] = builder.getIndexAttr(staticSize);
+    if (i == dim)
+      concatDimIsStatic = true;
+  }
+
+  // Nothing left to compute when the concatenated size is known statically.
+  if (concatDimIsStatic)
+    return success();
+
+  // Otherwise add the sizes of the remaining inputs along `dim` to the first
+  // input's size seeded above.
+  Value total =
+      getValueOrCreateConstantIndexOp(builder, getLoc(), shape[dim]);
+  for (Value input : inputs.drop_front()) {
+    Value inputSize =
+        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim);
+    total = builder.create<arith::AddIOp>(input.getLoc(), total, inputSize);
+  }
+
+  shape[dim] = total;
+  return success();
+}
+
+void ConcatOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getOperation()->getResult(0), "concat"); // Printed as `%concat`.
+}
+
+OpFoldResult ConcatOp::fold(FoldAdaptor) {
+  // A single input whose type already equals the result type is the result.
+  if (getInputs().size() != 1 || getInputs()[0].getType() != getResultType())
+    return {};
+  return getInputs()[0];
+}
+
+namespace {
+/// Canonicalizes a single-input `tensor.concat` into a `tensor.cast`.
+struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    OperandRange inputs = concatOp.getInputs();
+    if (inputs.size() != 1)
+      return failure();
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        inputs.front());
+    return success();
+  }
+};
+} // namespace
+
+void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<SingleInputConcatOp>(context); // concat(x) -> tensor.cast(x)
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c5fd4e65bbf7028..d233ab7a0e89741 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldIntoPackAndUnpackPatterns.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRAffineTransforms
   MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRIR
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
new file mode 100644
index 000000000000000..6fa34e5dd2b3f3e
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -0,0 +1,86 @@
+//===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Decompose `tensor.concat` into `tensor.empty` and a chain of slice inserts.
+///
+/// %concat = tensor.concat dim(1) %0, %1 :
+///         (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<2x7xf32>
+/// %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1]
+/// %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1]
+struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    // Build a destination tensor sized like the concat result.
+    Location loc = concatOp.getLoc();
+    FailureOr<Value> dest =
+        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+    if (failed(dest))
+      return failure();
+    // Only decompose when the destination is a fresh `tensor.empty`.
+    auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+    if (!empty)
+      return failure();
+
+    int64_t dim = concatOp.getDim();
+    Value dimValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(dim));
+    // Slices use unit strides and zero offsets except along `dim`.
+    int64_t rank = concatOp.getResultType().getRank();
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> dimOffsets;
+    dimOffsets.push_back(rewriter.getIndexAttr(0));
+    // Input i starts at the sum of the sizes of inputs 0..i-1 along `dim`.
+    for (Value input : concatOp.getInputs().drop_back()) {
+      Value size = rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue);
+      Value currentOffset =
+          getValueOrCreateConstantIndexOp(rewriter, loc, dimOffsets.back());
+      Value total =
+          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
+      dimOffsets.push_back(getAsOpFoldResult(total));
+    }
+    // Insert every input into the destination at its offset along `dim`.
+    Value result = *dest;
+    for (auto [input, offset] :
+         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
+      SmallVector<OpFoldResult> sizes =
+          tensor::getMixedSizes(rewriter, loc, input);
+      offsets[dim] = offset;
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, input, result, offsets, sizes, strides);
+    }
+    // Cast to the declared result type, which may be less static than `dest`.
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        concatOp, concatOp.getResultType(), result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateDecomposeTensorConcatPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ea8c17640d7c143..b77197b6cd9f1e9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32>
 
 // -----
 
+// CHECK-LABEL: fold_concat
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
+func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
+  %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
+  // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
+  %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
+  // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+  return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
new file mode 100644
index 000000000000000..b4f6a25f334e3b1
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-decompose-concat -cse  %s | FileCheck %s
+
+func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @decompose_dynamic_concat(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+
+//   CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT_SIZE:.+]] = arith.addi %[[DIM]], %[[C4]] : index
+//       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
+//       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
+//       CHECK:     %[[DIM0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[DIM0]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
+//       CHECK:     %[[CAST_BACK:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
+//       CHECK:     return %[[CAST_BACK]] : tensor<?x?xf32>
+
+// -----
+
+func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
+                            %arg1 : tensor<2xf32>,
+                            %arg2 : tensor<3xf32>,
+                            %arg3: tensor<4xf32>) -> tensor<10xf32> {
+  %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3
+             : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+// CHECK-LABEL: func @decompose_1d_concat
+//       CHECK:    tensor.empty() : tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
+//       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
+//       CHECK:    return %[[CONCAT]] : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 389e7e675c0eeda..426372f0ba4e746 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
 
 // -----
 
+func.func @concat_empty() {
+  // expected-error@+1 {{requires at least one input}}
+  %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) {
+  // expected-error@+1 {{rank of concatenated inputs must match result rank}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
+  return
+}
+
+// -----
+
+func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) {
+  // expected-error@+1 {{concatenation dim must be less than the tensor rank}}
+  %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) {
+  // expected-error@+1 {{inputs and result element type must match}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) {
+  // expected-error@+1 {{failed to infer concatenation result type from inputs}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32>
+  return
+}
+
+// -----
+
+func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
+  // expected-error@+1 {{result type 'tensor<7xf32>'does not match inferred shape 'tensor<6xf32>' static sizes}}
+  %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32>
+  return
+}
+
+// -----
+
 func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
   // expected-error at +1 {{incorrect number of indices for extract_element}}
   %0 = tensor.extract %arg0[] : tensor<?xf32>
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 71a0489b23f5f2d..2282da38803af0b 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
 
 // -----
 
+// CHECK-LABEL: func @concat(
+func.func @concat(%a: tensor<4x7x3xf32>, %b: tensor<4x4x3xf32>, %dyn: tensor<?x?x?xf32>) {
+  // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  %single = tensor.concat dim(0) %a : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %static = tensor.concat dim(1) %a, %b : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %mixed = tensor.concat dim(2) %a, %dyn : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  %refined = tensor.concat dim(1) %dyn, %dyn : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  %three = tensor.concat dim(1) %dyn, %b, %a : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @empty(
 //  CHECK-SAME:             %[[sz:.*]]: index
 func.func @empty(%sz: index) -> tensor<5x?x6xf32> {
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 3e142155df8d9b8..bcddfd313ef0484 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,6 +77,10 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
       llvm::cl::init(false)};
 
+  Option<bool> testDecomposeConcat{
+      *this, "test-decompose-concat",
+      llvm::cl::desc("Test decomposing tensor.concat"), llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -108,6 +112,12 @@ static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyDecomposeConcatPatterns(Operation *rootOp) {
+  RewritePatternSet concatPatterns(rootOp->getContext());
+  tensor::populateDecomposeTensorConcatPatterns(concatPatterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(concatPatterns));
+}
+
 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::ControlConstantExtractSliceFusionFn controlFn =
@@ -388,6 +398,8 @@ void TestTensorTransforms::runOnOperation() {
     applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testFoldIntoPackAndUnpack)
     applyFoldIntoPackAndUnpackPatterns(rootOp);
+  if (testDecomposeConcat)
+    applyDecomposeConcatPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))



More information about the Mlir-commits mailing list