[Mlir-commits] [mlir] f310a5d - [mlir][tensor] Add a tensor.concat operation (#72779)

llvmlistbot at llvm.org
Fri Dec 1 12:05:37 PST 2023


Author: Quinn Dawkins
Date: 2023-12-01T15:05:29-05:00
New Revision: f310a5d2c13455f1d68f5654fa4258357bafeff6

URL: https://github.com/llvm/llvm-project/commit/f310a5d2c13455f1d68f5654fa4258357bafeff6
DIFF: https://github.com/llvm/llvm-project/commit/f310a5d2c13455f1d68f5654fa4258357bafeff6.diff

LOG: [mlir][tensor] Add a tensor.concat operation (#72779)

This adds an operation for concatenating ranked tensors along a static
dimension, as well as a decomposition mirroring the existing lowering
from TOSA to Tensor. This offers a convergence point for "input"-like
dialects that include various lowerings for concatenation operations,
easing later analysis. In the future, this op can implement the
necessary interfaces for tiling and potentially gain conversions to
linalg and/or memref counterparts.

This patch adds the op, the decomposition, and some basic
folding/canonicalization. Replacing existing lowerings with this op
(such as the TOSA lowering) will come as a follow-up.

See
https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
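
For illustration, a minimal sketch of the new op and the decomposition this
patch adds (the values %a and %b and their shapes here are hypothetical; the
exact pattern lives in ConcatOpPatterns.cpp below):

```mlir
// Concatenate two tensors along dimension 0.
%concat = tensor.concat dim(0) %a, %b :
    (tensor<2x4xf32>, tensor<3x4xf32>) -> tensor<5x4xf32>

// After the decomposition patterns run, this becomes a chain of
// insert_slice ops into a materialized destination:
%empty = tensor.empty() : tensor<5x4xf32>
%slice0 = tensor.insert_slice %a into %empty[0, 0] [2, 4] [1, 1]
    : tensor<2x4xf32> into tensor<5x4xf32>
%result = tensor.insert_slice %b into %slice0[2, 0] [3, 4] [1, 1]
    : tensor<3x4xf32> into tensor<5x4xf32>
```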

Added: 
    mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
    mlir/test/Dialect/Tensor/decompose-concat.mlir

Modified: 
    mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
    mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
    mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
    mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
    mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
    mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/test/Dialect/Tensor/invalid.mlir
    mlir/test/Dialect/Tensor/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7ae27407a9526e7..f50e3464867be50 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -121,6 +121,70 @@ def Tensor_CastOp : Tensor_Op<"cast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ConcatOp : Tensor_Op<"concat",
+    [Pure,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "tensor concatenation operation";
+  let description = [{
+    The "concat" operation constructs a tensor out of a variadic list of input
+    tensors, concatenated along a static dimension number. All inputs and the
+    result type must share the same rank.
+
+    `dim` specifies the dimension along which to concatenate. The size of the
+    concatenated dimension in the result must be equal to the sum of the sizes
+    of the inputs along that dimension. All other dimensions in both the inputs
+    and result must be the same size.
+
+    Example:
+
+    ```mlir
+    %3 = tensor.concat dim(0) %0, %1, %2 :
+        (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32>) -> tensor<7x6xf32>
+
+    // Dynamic + dynamic -> static
+    %4 = tensor.concat dim(1) %0, %1, %2 :
+        (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>
+    ```
+  }];
+  let arguments = (ins I64Attr:$dim,
+                       Variadic<AnyRankedTensor>:$inputs);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    `dim` `(` $dim `)` $inputs attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let builders = [
+    // Builder with an inferred result type.
+    OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Helper to infer the concatenated result type for the given list of input
+    // types, being concatenated along `dim`. Because concatenation can specify
+    // more static information than can automatically be inferred,
+    // InferTypeOpInterface is not used.
+    static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);
+
+    RankedTensorType getResultType() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType());
+    }
+
+    int64_t getRank() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 66c6021418b471c..8556d9570fd1200 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -15,6 +15,18 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.decompose_concat",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.concat ops should be decomposed into a chain of
+    tensor.insert_slice operations inserting into a materialized destination.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

diff  --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 705b30e7ded4779..44b8377bd6aad99 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,13 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                                      bool foldSingleUseOnly = false);
 
+/// Populates `patterns` with patterns that decompose `tensor.concat` into
+/// `tensor.empty` of a tensor of the concatenated size, followed by a chain
+/// of `tensor.insert_slice` operations on the inputs. This is intended to be
+/// used as a fallback tensor -> tensor lowering that decomposes concat such
+/// that it can be bufferized into a sequence of copies.
+void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.

diff  --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index c2fbaea726abcbb..502ab93ddbfa7d7 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -151,6 +151,39 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                          OpFoldResult step);
 
+/// Idiomatic saturated operations on values like offsets, sizes, and strides.
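+/// A SaturatedInteger either holds a static value or is "saturated", i.e.
+/// dynamic/unknown (ShapedType::kDynamic); arithmetic involving a saturated
+/// operand produces a saturated result.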
+struct SaturatedInteger {
+  static SaturatedInteger wrap(int64_t v) {
+    return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
+                                      : SaturatedInteger{false, v};
+  }
+  int64_t asInteger() { return saturated ? ShapedType::kDynamic : v; }
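+  // Combine with `other`: a saturated value is refined by a static one, two
+  // equal static values combine to the same value, and two mismatching
+  // static values fail.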
+  FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) {
+    if (saturated && !other.saturated)
+      return other;
+    if (!saturated && !other.saturated && v != other.v)
+      return failure();
+    return *this;
+  }
+  bool operator==(SaturatedInteger other) {
+    return (saturated && other.saturated) ||
+           (!saturated && !other.saturated && v == other.v);
+  }
+  bool operator!=(SaturatedInteger other) { return !(*this == other); }
+  SaturatedInteger operator+(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v + v};
+  }
+  SaturatedInteger operator*(SaturatedInteger other) {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v * v};
+  }
+  bool saturated = true;
+  int64_t v = 0;
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fae8..dce96cca016ff8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -26,43 +26,6 @@
 using namespace mlir;
 using namespace mlir::memref;
 
-namespace {
-/// Idiomatic saturated operations on offsets, sizes and strides.
-namespace saturated_arith {
-struct Wrapper {
-  static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper size(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
-  bool operator==(Wrapper other) {
-    return (saturated && other.saturated) ||
-           (!saturated && !other.saturated && v == other.v);
-  }
-  bool operator!=(Wrapper other) { return !(*this == other); }
-  Wrapper operator+(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v + v};
-  }
-  Wrapper operator*(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v * v};
-  }
-  bool saturated;
-  int64_t v;
-};
-} // namespace saturated_arith
-} // namespace
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
     ReassociationIndices reassoc = std::get<0>(it);
     int64_t currentStrideToExpand = std::get<1>(it);
     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
-      using saturated_arith::Wrapper;
       reverseResultStrides.push_back(currentStrideToExpand);
-      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
-                               Wrapper::size(resultShape[shapeIndex--]))
-                                  .asStride();
+      currentStrideToExpand =
+          (SaturatedInteger::wrap(currentStrideToExpand) *
+           SaturatedInteger::wrap(resultShape[shapeIndex--]))
+              .asInteger();
     }
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
   unsigned resultStrideIndex = resultStrides.size() - 1;
   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
-    using saturated_arith::Wrapper;
-    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
-      stride = stride * Wrapper::size(srcShape[idx]);
+      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
@@ -2345,7 +2307,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
       // ops where obviously non-contiguous dims are collapsed, but accept ops
       // where we cannot be sure statically. Such ops may fail at runtime. See
       // the op documentation for details.
-      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
@@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    using saturated_arith::Wrapper;
-    auto groupSize = Wrapper::size(1);
+    auto groupSize = SaturatedInteger::wrap(1);
     for (int64_t srcDim : group)
-      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
-    resultShape.push_back(groupSize.asSize());
+      groupSize =
+          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asInteger());
   }
 
   if (srcType.getLayout().isIdentity()) {
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetOffset =
-        (Wrapper::offset(targetOffset) +
-         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
-            .asOffset();
+    targetOffset = (SaturatedInteger::wrap(targetOffset) +
+                    SaturatedInteger::wrap(staticOffset) *
+                        SaturatedInteger::wrap(targetStride))
+                       .asInteger();
   }
 
   // Compute target stride whose value is:
@@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetStrides.push_back(
-        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
-            .asStride());
+    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
+                             SaturatedInteger::wrap(staticStride))
+                                .asInteger());
   }
 
   // The type is now known.

diff  --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd9b82d2c553fae..02146e8257b38e3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -472,6 +472,192 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
+  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
+  auto tensorTypes =
+      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
+        return llvm::cast<RankedTensorType>(type);
+      }));
+  int64_t concatRank = tensorTypes[0].getRank();
+
+  // The concatenation dim must be in the range [0, rank).
+  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
+
+  SmallVector<int64_t> sizes(concatRank);
+  for (int64_t i = 0, e = concatRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : tensorTypes)
+      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+    sizes[i] = size.asInteger();
+  }
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : tensorTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
+}
+
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+                     ValueRange inputs) {
+  FailureOr<RankedTensorType> resultType =
+      inferResultType(dim, inputs.getTypes());
+  assert(succeeded(resultType) && "failed to infer concatenation result type");
+  build(builder, result, *resultType, dim, inputs);
+}
+
+LogicalResult ConcatOp::verify() {
+  if (getInputs().size() < 1)
+    return emitOpError("requires at least one input");
+
+  SmallVector<RankedTensorType> inputTypes;
+  for (auto input : getInputs())
+    inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+  RankedTensorType resultType = getResultType();
+  int64_t resultRank = getRank();
+  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
+        return type.getRank() != resultRank;
+      }))
+    return emitOpError("rank of concatenated inputs must match result rank");
+
+  Type resultElementType = resultType.getElementType();
+  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
+        return type.getElementType() != resultElementType;
+      }))
+    return emitOpError("inputs and result element type must match");
+
+  int64_t dim = getDim();
+  if (dim >= resultRank)
+    return emitOpError("concatenation dim must be less than the tensor rank");
+
+  SmallVector<int64_t> sizes(resultRank);
+  for (int64_t i = 0, e = resultRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : inputTypes) {
+      FailureOr<SaturatedInteger> maybeSize =
+          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+      if (failed(maybeSize))
+        return emitOpError("static concatenation size mismatch along ")
+               << "non-concatenated dimension " << i;
+      size = *maybeSize;
+    }
+    sizes[i] = size.asInteger();
+  }
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : inputTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  auto inferredResultType =
+      RankedTensorType::get(sizes, inputTypes[0].getElementType());
+
+  for (auto [inferredSize, actualSize] :
+       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
+    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
+                      ShapedType::isDynamic(actualSize);
+    if (!hasDynamic && inferredSize != actualSize)
+      return emitOpError("result type ")
+             << resultType << " does not match inferred shape "
+             << inferredResultType << " static sizes";
+  }
+
+  return success();
+}
+
+LogicalResult
+ConcatOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  ValueRange inputs = getInputs();
+  int64_t dim = getDim();
+  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
+
+  Value init = inputs[0];
+  int64_t rank = getType().getRank();
+
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
+
+  // Pre-populate the result sizes with as much static information as
+  // possible, preferring the given result type, then the inferred result
+  // type; otherwise fall back to the dim sizes of the first input.
+  for (int64_t i = 0; i < rank; ++i) {
+    if (i == dim)
+      continue;
+    if (!getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+    } else if (!inferredResultType.isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] =
+          builder.getIndexAttr(inferredResultType.getDimSize(i));
+    } else {
+      reifiedReturnShapes[0][i] =
+          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
+    }
+  }
+
+  // Take the sum of the input sizes along the concatenated dim.
+  AffineExpr sum = builder.getAffineDimExpr(0);
+  SmallVector<OpFoldResult> sizes = {
+      builder.create<tensor::DimOp>(init.getLoc(), init, dim).getResult()};
+  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+    sum = sum + builder.getAffineDimExpr(idx + 1);
+    sizes.push_back(
+        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+  }
+  reifiedReturnShapes[0][dim] =
+      affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
+
+  // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
+  // returns a Value for dynamic dimensions.
+  for (int64_t i = 0; i < rank; ++i) {
+    if (getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
+          builder, getLoc(), reifiedReturnShapes[0][i]);
+    }
+  }
+  return success();
+}
+
+void ConcatOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "concat");
+}
+
+OpFoldResult ConcatOp::fold(FoldAdaptor) {
+  ValueRange inputs = getInputs();
+  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
+    return inputs[0];
+  return {};
+}
+
+namespace {
+/// Fold a concat op with a single input to a cast.
+struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (concatOp.getInputs().size() != 1)
+      return failure();
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        concatOp.getInputs()[0]);
+    return success();
+  }
+};
+} // namespace
+
+void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<SingleInputConcatOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 3cec91389392246..ed274238704713c 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -83,6 +83,11 @@ void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
 
+void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateDecomposeTensorConcatPatterns(patterns);
+}
+
 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
     populatePatterns(RewritePatternSet &patterns) {
   tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c5fd4e65bbf7028..d233ab7a0e89741 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldIntoPackAndUnpackPatterns.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRAffineTransforms
   MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRIR

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
new file mode 100644
index 000000000000000..2108fc591055a82
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -0,0 +1,93 @@
+//===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Decompose `tensor.concat` into `tensor.empty` and a chain of slice inserts.
+///
+/// %concat = tensor.concat dim(1) %0, %1 :
+///         (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<2x7xf32>
+/// %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1]
+/// %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1]
+struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = concatOp.getLoc();
+    FailureOr<Value> dest =
+        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+    if (failed(dest))
+      return failure();
+
+    auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+    if (!empty)
+      return failure();
+
+    int64_t dim = concatOp.getDim();
+    Value dimValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(dim));
+
+    int64_t rank = concatOp.getResultType().getRank();
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+
+    // Compute the partial sums for the slice offsets.
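+    // The offset of input i along `dim` is the sum of the sizes of inputs
+    // [0, i) along `dim`; the first input starts at offset zero.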
+    AffineExpr sum = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> partialSums = {sum};
+    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
+    for (auto [idx, input] :
+         llvm::enumerate(concatOp.getInputs().drop_back())) {
+      sum = sum + rewriter.getAffineDimExpr(idx + 1);
+      partialSums.push_back(sum);
+      offsetStrides.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+    }
+    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+                                        partialSums, rewriter.getContext());
+    SmallVector<OpFoldResult> dimOffsets =
+        affine::makeComposedFoldedMultiResultAffineApply(
+            rewriter, loc, partialSumMap, offsetStrides);
+
+    // Construct the chain of insert_slice ops into the destination.
+    Value result = *dest;
+    for (auto [input, offset] :
+         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
+      SmallVector<OpFoldResult> sizes =
+          tensor::getMixedSizes(rewriter, loc, input);
+      offsets[dim] = offset;
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, input, result, offsets, sizes, strides);
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        concatOp, concatOp.getResultType(), result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateDecomposeTensorConcatPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
+}

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 580c1db6070201f..84c44a09aa3dd1c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32>
 
 // -----
 
+// CHECK-LABEL: fold_concat
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
+func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
+  %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
+  // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
+  %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
+  // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+  return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
new file mode 100644
index 000000000000000..5712c77a743d71b
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_concat
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
+
+func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @decompose_dynamic_concat(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+
+//   CHECK-DAG:     %[[C8:.+]] = arith.constant 8 : index
+//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]]
+//       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
+//       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
+//       CHECK:     %[[SIZE0:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[SIZE0]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+//       CHECK:     return %[[CONCAT]] : tensor<?x?xf32>
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_concat
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
+
+func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
+                            %arg1 : tensor<2xf32>,
+                            %arg2 : tensor<3xf32>,
+                            %arg3: tensor<4xf32>) -> tensor<10xf32> {
+  %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3
+             : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+// CHECK-LABEL: func @decompose_1d_concat
+//       CHECK:    tensor.empty() : tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
+//       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
+//       CHECK:    return %[[CONCAT]] : tensor<10xf32>

diff  --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 389e7e675c0eeda..9b6c2327879cf9b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
 
 // -----
 
+func.func @concat_empty() {
+  // expected-error @+1 {{requires at least one input}}
+  %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) {
+  // expected-error @+1 {{rank of concatenated inputs must match result rank}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
+  return
+}
+
+// -----
+
+func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) {
+  // expected-error @+1 {{concatenation dim must be less than the tensor rank}}
+  %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) {
+  // expected-error @+1 {{inputs and result element type must match}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) {
+  // expected-error @+1 {{static concatenation size mismatch along non-concatenated dimension 1}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32>
+  return
+}
+
+// -----
+
+func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
+  // expected-error @+1 {{result type 'tensor<7xf32>' does not match inferred shape 'tensor<6xf32>' static sizes}}
+  %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32>
+  return
+}
+
+// -----
+
 func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
  // expected-error @+1 {{incorrect number of indices for extract_element}}
   %0 = tensor.extract %arg0[] : tensor<?xf32>

diff  --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 71a0489b23f5f2d..2282da38803af0b 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
 
 // -----
 
+// CHECK-LABEL: func @concat(
+func.func @concat(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>) {
+  // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @empty(
 //  CHECK-SAME:             %[[sz:.*]]: index
 func.func @empty(%sz: index) -> tensor<5x?x6xf32> {

More information about the Mlir-commits mailing list