[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)

Quinn Dawkins llvmlistbot at llvm.org
Fri Dec 1 11:46:23 PST 2023


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/72779

>From bc1766abaffd4d27e93ef879bf85ae2cd35c7d03 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sat, 18 Nov 2023 17:02:40 -0500
Subject: [PATCH 1/5] [mlir][tensor] Add tensor.concat operation

This adds an operation for concatenating ranked tensors along a static
dimension, as well as a decomposition mirroring the existing lowering
from TOSA to Tensor. This offers a common target for "input"-like dialects
that include various lowerings for concatenation operations, easing later
analysis.
In the future, this op can implement the necessary interfaces for
tiling, as well as potentially add conversions to some kind of linalg
and/or memref counterpart.

See https://discourse.llvm.org/t/rfc-tensor-add-a-tensor-concatenate-operation/74858
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  60 ++++++
 .../Dialect/Tensor/Transforms/Transforms.h    |   5 +
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 191 ++++++++++++++++++
 .../Dialect/Tensor/Transforms/CMakeLists.txt  |   2 +
 .../Tensor/Transforms/ConcatOpPatterns.cpp    |  86 ++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  12 ++
 .../test/Dialect/Tensor/decompose-concat.mlir |  39 ++++
 mlir/test/Dialect/Tensor/invalid.mlir         |  48 +++++
 mlir/test/Dialect/Tensor/ops.mlir             |  17 ++
 .../Dialect/Tensor/TestTensorTransforms.cpp   |  12 ++
 10 files changed, 472 insertions(+)
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
 create mode 100644 mlir/test/Dialect/Tensor/decompose-concat.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7ae27407a9526e7..9101b7fa7c04005 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -121,6 +121,66 @@ def Tensor_CastOp : Tensor_Op<"cast", [
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_ConcatOp : Tensor_Op<"concat",
+    [Pure,
+     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+     DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
+  let summary = "tensor concatenation operation";
+  let description = [{
+    The "concat" operation constructs a tensor out of a variadic list of input
+    tensors, concatenated along a static dimension. All inputs and the result
+    type must share the same rank.
+
+    `dim` specifies the dimension along which to concatenate. The size of the
+    concatenated dimension in the result must be equal to the sum of the sizes
+    of the inputs along that dimension. All other dimensions in both the inputs
+    and result must be the same size.
+
+    Example:
+
+    ```mlir
+    %concat0 = tensor.concat dim(0) %0, %1, %2 :
+        (tensor<3x6xf32>, tensor<3x6xf32>, tensor<1x6xf32>) -> tensor<7x6xf32>
+
+    // Mixed dynamic and static inputs -> static result
+    %concat1 = tensor.concat dim(1) %0, %1, %2 :
+        (tensor<3x?xf32>, tensor<3x2xf32>, tensor<3x?xf32>) -> tensor<3x10xf32>
+    ```
+  }];
+  let arguments = (ins I64Attr:$dim,
+                       Variadic<AnyRankedTensor>:$inputs);
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    `dim` `(` $dim `)` $inputs attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let builders = [
+    // Builder with an inferred result type.
+    OpBuilder<(ins "int64_t":$dim, "ValueRange":$inputs)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Helper to infer the concatenated result type for the given list of input
+    // types, being concatenated along `dim`. Because concatenation can specify
+    // more static information than can automatically be inferred,
+    // InferTypeOpInterface is not used.
+    static FailureOr<RankedTensorType> inferResultType(int64_t dim, TypeRange inputTypes);
+
+    RankedTensorType getResultType() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType());
+    }
+  }];
+
+  let hasCanonicalizer = 1;
+  let hasFolder = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 705b30e7ded4779..c1627d20c2b3b6c 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,11 @@ void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
                                      bool foldSingleUseOnly = false);
 
+/// Populates `patterns` with patterns that decompose `tensor.concat` into
+/// `tensor.empty` of a tensor of the concatenated size, followed by a chain
+/// of `tensor.insert_slice` operations on the inputs.
+void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
 /// and `tensor.extract_slice` into `tensor.pack` and `tensor.unpack` operations
 /// respectively.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index cd9b82d2c553fae..41fc5a632f49e2a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -472,6 +472,197 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+/// Infers the type produced by concatenating `inputTypes` along `dim`.
+///
+/// Returns failure if `dim` is negative, `inputTypes` is empty, any input is
+/// not a RankedTensorType, `dim` is not less than the first input's rank, or
+/// the inputs disagree on rank, element type, or a static size outside the
+/// concatenated dimension. On success, the concatenated dimension is the sum
+/// of the input sizes (dynamic if any input size there is dynamic), and every
+/// other dimension is static whenever at least one input gives a static size.
+FailureOr<RankedTensorType> ConcatOp::inferResultType(int64_t dim,
+                                                      TypeRange inputTypes) {
+  if (dim < 0)
+    return failure();
+
+  if (inputTypes.empty())
+    return failure();
+
+  RankedTensorType init = dyn_cast<RankedTensorType>(inputTypes[0]);
+  if (!init)
+    return failure();
+
+  // The tensor rank must be greater than the concatenation dim.
+  int64_t concatRank = init.getRank();
+  if (concatRank <= dim)
+    return failure();
+
+  // Start from the first input's shape and merge each remaining input into it.
+  SmallVector<int64_t> sizes(init.getShape());
+  Type elementType = init.getElementType();
+  for (Type type : inputTypes.drop_front()) {
+    RankedTensorType tensorType = dyn_cast<RankedTensorType>(type);
+    if (!tensorType || tensorType.getRank() != concatRank ||
+        tensorType.getElementType() != elementType)
+      return failure();
+
+    for (auto [index, currSize] : llvm::enumerate(tensorType.getShape())) {
+      int64_t size = sizes[index];
+      bool hasDynamic =
+          ShapedType::isDynamic(size) || ShapedType::isDynamic(currSize);
+      if (static_cast<int64_t>(index) == dim) {
+        // Accumulate along the concatenated dimension. Once the running total
+        // is dynamic it stays dynamic for all later inputs.
+        sizes[index] = hasDynamic ? ShapedType::kDynamic : currSize + size;
+        continue;
+      }
+
+      // If the sizes are statically different for a dimension other than the
+      // concatenated dimension, the concatenation is invalid. Both dynamic or
+      // mixed dynamic and static is fine.
+      if (currSize != size && !hasDynamic)
+        return failure();
+
+      // If the new size is not dynamic, use the additional static information.
+      if (!ShapedType::isDynamic(currSize))
+        sizes[index] = currSize;
+    }
+  }
+
+  return RankedTensorType::get(sizes, elementType);
+}
+
+/// Builds a concat whose result type is inferred from `inputs` with
+/// `inferResultType`. Inference must succeed, so callers are expected to pass
+/// inputs that can legally be concatenated along `dim`.
+/// NOTE(review): with assertions disabled, a failed inference is dereferenced
+/// without a check below.
+void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
+                     ValueRange inputs) {
+  FailureOr<RankedTensorType> resultType =
+      inferResultType(dim, inputs.getTypes());
+  assert(succeeded(resultType) && "failed to infer concatenation result type");
+  build(builder, result, *resultType, dim, inputs);
+}
+
+/// Verifies that the op has at least one input, that every input matches the
+/// result's rank and element type, that `dim` is below the result rank, and
+/// that the declared result shape agrees with the inferred shape in every
+/// dimension where both sizes are static.
+LogicalResult ConcatOp::verify() {
+  if (getInputs().size() < 1)
+    return emitOpError("requires at least one input");
+
+  // The ODS constraint Variadic<AnyRankedTensor> on `inputs` makes this cast
+  // safe.
+  SmallVector<RankedTensorType> inputTypes;
+  for (auto input : getInputs())
+    inputTypes.push_back(cast<RankedTensorType>(input.getType()));
+
+  RankedTensorType resultType = getResultType();
+
+  // Rank and element type are checked before inference so these common
+  // mistakes get a specific diagnostic instead of the generic one below.
+  int64_t resultRank = resultType.getRank();
+  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
+        return type.getRank() != resultRank;
+      }))
+    return emitOpError("rank of concatenated inputs must match result rank");
+
+  Type resultElementType = resultType.getElementType();
+  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
+        return type.getElementType() != resultElementType;
+      }))
+    return emitOpError("inputs and result element type must match");
+
+  // NOTE(review): a negative `dim` is not caught here. It is rejected by
+  // inferResultType below (which fails for dim < 0) with the generic message.
+  if (static_cast<int64_t>(getDim()) >= resultRank)
+    return emitOpError("concatenation dim must be less than the tensor rank");
+
+  FailureOr<RankedTensorType> inferredResultType =
+      inferResultType(getDim(), getInputs().getTypes());
+  if (failed(inferredResultType))
+    return emitOpError("failed to infer concatenation result type from inputs");
+
+  // A dynamic size on either side is compatible with anything; only two
+  // differing static sizes are an error.
+  for (auto [inferredSize, actualSize] :
+       llvm::zip_equal(inferredResultType->getShape(), resultType.getShape())) {
+    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
+                      ShapedType::isDynamic(actualSize);
+    if (!hasDynamic && inferredSize != actualSize)
+      // NOTE(review): the message has no space between the printed type and
+      // "does not match"; invalid.mlir currently pins this exact text.
+      return emitOpError("result type ")
+             << resultType << "does not match inferred shape "
+             << *inferredResultType << " static sizes";
+  }
+
+  return success();
+}
+
+/// Reifies the concat's result shape into `reifiedReturnShapes[0]`.
+///
+/// Each dimension is taken, in order of preference, from the declared result
+/// type, from the type inferred from the inputs, or from the first input's
+/// size. If the concatenated dimension is still unknown after that, it is
+/// built as the sum of every input's size along `dim`. Fails only if the
+/// result type cannot be inferred from the inputs.
+/// NOTE(review): dimensions that are static only in the inferred type are
+/// returned as index attributes even though they are dynamic in the result
+/// type; confirm this satisfies the ReifyRankedShapedTypeOpInterface contract.
+LogicalResult
+ConcatOp::reifyResultShapes(OpBuilder &builder,
+                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  ValueRange inputs = getInputs();
+  int64_t dim = getDim();
+  FailureOr<RankedTensorType> maybeInferredResultType =
+      inferResultType(dim, inputs.getTypes());
+  if (failed(maybeInferredResultType))
+    return failure();
+  RankedTensorType inferredResultType = *maybeInferredResultType;
+
+  Value init = inputs[0];
+  int64_t rank = getType().getRank();
+
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
+
+  // Pre-populate the result sizes with as much static information as possible
+  // from the given result type, as well as the inferred result type, otherwise
+  // use the dim sizes from the first input.
+  bool hasStaticConcatDim = false;
+  for (int64_t i = 0; i < rank; ++i) {
+    if (!getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
+      if (i == dim)
+        hasStaticConcatDim = true;
+    } else if (!inferredResultType.isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] =
+          builder.getIndexAttr(inferredResultType.getDimSize(i));
+      if (i == dim)
+        hasStaticConcatDim = true;
+    } else {
+      reifiedReturnShapes[0][i] =
+          tensor::getMixedSize(builder, init.getLoc(), init, i);
+    }
+  }
+
+  // Check if we already know the size of the concatenation dim statically.
+  if (hasStaticConcatDim)
+    return success();
+
+  // Take the sum of the input sizes along the concatenated dim. The running
+  // total starts from the first input's size (set by the last branch above),
+  // so only the remaining inputs are added.
+  Value concatValue = getValueOrCreateConstantIndexOp(
+      builder, getLoc(), reifiedReturnShapes[0][dim]);
+
+  for (Value input : inputs.drop_front()) {
+    Value newSize =
+        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim);
+    concatValue =
+        builder.create<arith::AddIOp>(input.getLoc(), concatValue, newSize);
+  }
+  reifiedReturnShapes[0][dim] = concatValue;
+  return success();
+}
+
+/// Suggests `concat` as the SSA name for the op's result when printing IR.
+void ConcatOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  setNameFn(getResult(), "concat");
+}
+
+/// Folds a single-input concat whose input type already equals the result
+/// type to that input. Single-input concats that change the type are left to
+/// the SingleInputConcatOp canonicalization pattern.
+OpFoldResult ConcatOp::fold(FoldAdaptor) {
+  ValueRange inputs = getInputs();
+  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
+    return inputs[0];
+  return {};
+}
+
+namespace {
+/// Rewrites a concat with exactly one input into a `tensor.cast` of that input
+/// to the concat's result type. This covers single-input concats whose result
+/// type differs from the input type, which `ConcatOp::fold` does not handle.
+struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    if (concatOp.getInputs().size() != 1)
+      return failure();
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        concatOp.getInputs()[0]);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the concat canonicalization patterns (currently only
+/// SingleInputConcatOp).
+void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<SingleInputConcatOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // DimOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index c5fd4e65bbf7028..d233ab7a0e89741 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  ConcatOpPatterns.cpp
   EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   FoldIntoPackAndUnpackPatterns.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   MLIRAffineTransforms
   MLIRAffineUtils
   MLIRArithDialect
+  MLIRArithUtils
   MLIRBufferizationDialect
   MLIRBufferizationTransforms
   MLIRIR
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
new file mode 100644
index 000000000000000..6fa34e5dd2b3f3e
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -0,0 +1,86 @@
+//===- ConcatOpPatterns.cpp - Patterns related to tensor.concat lowering --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+/// Decompose `tensor.concat` into `tensor.empty` and a chain of slice inserts.
+///
+/// %concat = tensor.concat dim(1) %0, %1 :
+///         (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
+///
+/// Becomes
+///
+/// %empty = tensor.empty() : tensor<2x7xf32>
+/// %insert0 = tensor.insert_slice %0 into %empty[0, 0][2, 3][1, 1]
+/// %concat = tensor.insert_slice %1 into %insert0[0, 3][2, 4][1, 1]
+///
+/// Input `i` is inserted at offset `size(0) + ... + size(i - 1)` along `dim`
+/// (0 for the first input). All other offsets are 0 and all strides are 1.
+/// The final value is cast back to the op's declared result type.
+struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = concatOp.getLoc();
+    FailureOr<Value> dest =
+        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+    if (failed(dest))
+      return failure();
+
+    // Only rewrite into a fresh `tensor.empty` destination.
+    // NOTE(review): ops created by getOrCreateDestination may be left behind
+    // on this failure path; confirm that is acceptable for the driver.
+    auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+    if (!empty)
+      return failure();
+
+    int64_t dim = concatOp.getDim();
+    Value dimValue = rewriter.createOrFold<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(dim));
+
+    int64_t rank = concatOp.getResultType().getRank();
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> dimOffsets;
+    dimOffsets.push_back(rewriter.getIndexAttr(0));
+
+    // Running prefix sum of the input sizes along `dim`: input `k` shifts the
+    // offset of input `k + 1`, so the sizes come from every input except the
+    // last (not the first). This yields exactly one offset per input.
+    for (auto input : concatOp.getInputs().drop_back()) {
+      Value size = rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue);
+      Value currentOffset =
+          getValueOrCreateConstantIndexOp(rewriter, loc, dimOffsets.back());
+      Value total =
+          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
+      dimOffsets.push_back(getAsOpFoldResult(total));
+    }
+
+    Value result = *dest;
+
+    // zip_equal asserts the one-offset-per-input invariant established above.
+    for (auto [input, offset] :
+         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
+      SmallVector<OpFoldResult> sizes =
+          tensor::getMixedSizes(rewriter, loc, input);
+      offsets[dim] = offset;
+      result = rewriter.createOrFold<tensor::InsertSliceOp>(
+          loc, input, result, offsets, sizes, strides);
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        concatOp, concatOp.getResultType(), result);
+    return success();
+  }
+};
+
+} // namespace
+
+// Adds the DecomposeTensorConcatOp rewrite to `patterns`.
+void mlir::tensor::populateDecomposeTensorConcatPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<DecomposeTensorConcatOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 580c1db6070201f..84c44a09aa3dd1c 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -87,6 +87,18 @@ func.func @tensor.cast_chain_invalid(%input: tensor<4x8xi32>) -> tensor<8x4xi32>
 
 // -----
 
+// A single-input concat that only refines the type becomes a tensor.cast,
+// while one whose input already has the result type folds to that input.
+// CHECK-LABEL: fold_concat
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x2x?xi32>
+func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1x2x?xi32>) {
+  %0 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x3xi32>
+  // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ARG0]] : tensor<1x2x?xi32> to tensor<1x2x3xi32>
+  %1 = tensor.concat dim(2) %arg0 : (tensor<1x2x?xi32>) -> tensor<1x2x?xi32>
+  // CHECK-NEXT: return %[[CAST]], %[[ARG0]] : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+  return %0, %1 : tensor<1x2x3xi32>, tensor<1x2x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
new file mode 100644
index 000000000000000..b4f6a25f334e3b1
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-decompose-concat -cse  %s | FileCheck %s
+
+func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @decompose_dynamic_concat(
+//  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+
+//   CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
+//   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT_SIZE:.+]] = arith.addi %[[DIM]], %[[C4]] : index
+//       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
+//       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
+//       CHECK:     %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
+//       CHECK:     %[[CAST_BACK:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
+//       CHECK:     return %[[CAST_BACK]] : tensor<?x?xf32>
+
+// -----
+
+func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
+                            %arg1 : tensor<2xf32>,
+                            %arg2 : tensor<3xf32>,
+                            %arg3: tensor<4xf32>) -> tensor<10xf32> {
+  %0 = tensor.concat dim(0) %arg0, %arg1, %arg2, %arg3
+             : (tensor<1xf32>, tensor<2xf32>, tensor<3xf32>, tensor<4xf32>) -> tensor<10xf32>
+  return %0 : tensor<10xf32>
+}
+// CHECK-LABEL: func @decompose_1d_concat
+//       CHECK:    tensor.empty() : tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
+//       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
+//       CHECK:    return %[[CONCAT]] : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 389e7e675c0eeda..426372f0ba4e746 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,54 @@ func.func @tensor.cast_mismatching_constants(%arg0: tensor<1xf32>) {
 
 // -----
 
+// Verifier diagnostics for tensor.concat.
+func.func @concat_empty() {
+  // expected-error@+1 {{requires at least one input}}
+  %0 = tensor.concat dim(0) : () -> tensor<1x2x3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_rank_mismatch(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) {
+  // expected-error@+1 {{rank of concatenated inputs must match result rank}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<2x1xf32>
+  return
+}
+
+// -----
+
+func.func @concat_dim_out_of_range(%arg0: tensor<3xf32>) {
+  // expected-error@+1 {{concatenation dim must be less than the tensor rank}}
+  %0 = tensor.concat dim(1) %arg0 : (tensor<3xf32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi32>) {
+  // expected-error@+1 {{inputs and result element type must match}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3xf32>, tensor<3xi32>) -> tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) {
+  // expected-error@+1 {{failed to infer concatenation result type from inputs}}
+  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32>
+  return
+}
+
+// -----
+
+func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
+  // expected-error@+1 {{result type 'tensor<7xf32>'does not match inferred shape 'tensor<6xf32>' static sizes}}
+  %0 = tensor.concat dim(0) %arg0, %arg0 : (tensor<3xf32>, tensor<3xf32>) -> tensor<7xf32>
+  return
+}
+
+// -----
+
 func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
   // expected-error at +1 {{incorrect number of indices for extract_element}}
   %0 = tensor.extract %arg0[] : tensor<?xf32>
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
index 71a0489b23f5f2d..2282da38803af0b 100644
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
 
 // -----
 
+// tensor.concat with static, fully dynamic and mixed shapes (including a
+// result more static than its inputs) is accepted and printed back unchanged.
+// CHECK-LABEL: func @concat(
+func.func @concat(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>) {
+  // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @empty(
 //  CHECK-SAME:             %[[sz:.*]]: index
 func.func @empty(%sz: index) -> tensor<5x?x6xf32> {
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 3e142155df8d9b8..bcddfd313ef0484 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,6 +77,10 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
       llvm::cl::init(false)};
 
+  // When set, runOnOperation applies the tensor.concat decomposition patterns.
+  Option<bool> testDecomposeConcat{
+      *this, "test-decompose-concat",
+      llvm::cl::desc("Test decomposing tensor.concat"), llvm::cl::init(false)};
+
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -108,6 +112,12 @@ static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+// Greedily applies the tensor.concat decomposition patterns under `rootOp`.
+// Like the neighboring helpers, the driver's convergence result is ignored.
+static void applyDecomposeConcatPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateDecomposeTensorConcatPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::ControlConstantExtractSliceFusionFn controlFn =
@@ -388,6 +398,8 @@ void TestTensorTransforms::runOnOperation() {
     applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testFoldIntoPackAndUnpack)
     applyFoldIntoPackAndUnpackPatterns(rootOp);
+  if (testDecomposeConcat)
+    applyDecomposeConcatPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))

>From d5312462ce58cc720c32dc34d2f1dd86a99723a3 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Sat, 18 Nov 2023 23:57:18 -0500
Subject: [PATCH 2/5] Fix reifyResultShapes implementation

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp       | 13 +++++++++----
 mlir/test/Dialect/Tensor/decompose-concat.mlir | 10 +++++-----
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 41fc5a632f49e2a..3a777f32e09a317 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -602,13 +602,18 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
       if (i == dim)
         hasStaticConcatDim = true;
     } else if (!inferredResultType.isDynamicDim(i)) {
+      // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
+      // returns a Value for dynamic dimensions.
       reifiedReturnShapes[0][i] =
-          builder.getIndexAttr(inferredResultType.getDimSize(i));
+          builder
+              .create<arith::ConstantIndexOp>(getLoc(),
+                                              inferredResultType.getDimSize(i))
+              .getResult();
       if (i == dim)
         hasStaticConcatDim = true;
     } else {
       reifiedReturnShapes[0][i] =
-          tensor::getMixedSize(builder, init.getLoc(), init, i);
+          builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
     }
   }
 
@@ -623,8 +628,8 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
   for (Value input : inputs.drop_front()) {
     Value newSize =
         builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim);
-    concatValue =
-        builder.create<arith::AddIOp>(input.getLoc(), concatValue, newSize);
+    concatValue = builder.createOrFold<arith::AddIOp>(input.getLoc(),
+                                                      concatValue, newSize);
   }
   reifiedReturnShapes[0][dim] = concatValue;
   return success();
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index b4f6a25f334e3b1..dd7b9cd5f1490ae 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -8,17 +8,17 @@ func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?x
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 
+//   CHECK-DAG:     %[[C8:.+]] = arith.constant 8 : index
 //   CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 //       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
 //       CHECK:     %[[CONCAT_SIZE:.+]] = arith.addi %[[DIM]], %[[C4]] : index
-//       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[CONCAT_SIZE]]) : tensor<8x?xf32>
-//       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<8x?xf32>
+//       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
+//       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
 //       CHECK:     %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<8x?xf32>
-//       CHECK:     %[[CAST_BACK:.+]] = tensor.cast %[[CONCAT]] : tensor<8x?xf32> to tensor<?x?xf32>
-//       CHECK:     return %[[CAST_BACK]] : tensor<?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+//       CHECK:     return %[[CONCAT]] : tensor<?x?xf32>
 
 // -----
 

>From df7b7a4e7990e6e8917427076029e011f5d3c11e Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Wed, 29 Nov 2023 23:55:15 -0500
Subject: [PATCH 3/5] Rebase and address comments

- Simplify various implementations
- Make inferResultType always return RankedTensorType
- Import saturated_arith and reuse for concat
- Clarify reason for decomposing
- Use Affine folding helpers
- Switch test to a transform op
---
 .../mlir/Dialect/Tensor/IR/TensorOps.td       |  10 +-
 .../Tensor/TransformOps/TensorTransformOps.td |  12 ++
 .../Dialect/Tensor/Transforms/Transforms.h    |   4 +-
 .../mlir/Dialect/Utils/StaticValueUtils.h     |  33 ++++
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  76 ++-------
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp      | 158 ++++++++----------
 .../TransformOps/TensorTransformOps.cpp       |   5 +
 .../Tensor/Transforms/ConcatOpPatterns.cpp    |  33 ++--
 .../test/Dialect/Tensor/decompose-concat.mlir |  32 +++-
 mlir/test/Dialect/Tensor/invalid.mlir         |   2 +-
 .../Dialect/Tensor/TestTensorTransforms.cpp   |  12 --
 11 files changed, 198 insertions(+), 179 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 9101b7fa7c04005..f50e3464867be50 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -132,8 +132,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
   let summary = "tensor concatenation operation";
   let description = [{
     The "concat" operation constructs a tensor out of a variadic list of input
-    tensors, concatenated along a static dimension. All inputs and the result
-    type must share the same rank.
+    tensors, concatenated along a static dimension number. All inputs and the
+    result type must share the same rank.
 
     `dim` specifies the dimension along which to concatenate. The size of the
     concatenated dimension in the result must be equal to the sum of the sizes
@@ -169,11 +169,15 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
     // types, being concatenated along `dim`. Because concatenation can specify
     // more static information than can automatically be inferred,
     // InferTypeOpInterface is not used.
-    static FailureOr<RankedTensorType> inferResultType(int64_t dim, TypeRange inputTypes);
+    static RankedTensorType inferResultType(int64_t dim, TypeRange inputTypes);
 
     RankedTensorType getResultType() {
       return ::llvm::cast<RankedTensorType>(getResult().getType());
     }
+
+    int64_t getRank() {
+      return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
+    }
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
index 66c6021418b471c..8556d9570fd1200 100644
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -15,6 +15,18 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def ApplyDecomposeTensorConcatPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.tensor.decompose_concat",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that tensor.concat ops should be decomposed into a chain of
+    tensor.insert_slice operations inserting into a materialized destination.
+  }];
+  // No operands or results are declared; only the attr-dict is printed.
+  let assemblyFormat = "attr-dict";
+}
+
+
 def ApplyDropRedundantInsertSliceRankExpansionPatternsOp : Op<Transform_Dialect,
     "apply_patterns.tensor.drop_redundant_insert_slice_rank_expansion",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index c1627d20c2b3b6c..44b8377bd6aad99 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -69,7 +69,9 @@ void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns,
 
 /// Populates `patterns` with patterns that decompose `tensor.concat` into
 /// `tensor.empty` of a tensor of the concatenated size, followed by a chain
-/// of `tensor.insert_slice` operations on the inputs.
+/// of `tensor.insert_slice` operations on the inputs. This is intended to be
+/// used as a fallback tensor -> tensor lowering that decomposes concat such
+/// that it can be bufferized into a sequence of copies.
 void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
 
 /// Populates `patterns` with patterns that fold operations like `tensor.pad`
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index c2fbaea726abcbb..502ab93ddbfa7d7 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -151,6 +151,39 @@ LogicalResult foldDynamicIndexList(SmallVectorImpl<OpFoldResult> &ofrs,
 std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
                                          OpFoldResult step);
 
+/// Idiomatic saturated operations on values like offsets, sizes, and strides.
+struct SaturatedInteger {
+  static SaturatedInteger wrap(int64_t v) {
+    return (ShapedType::isDynamic(v)) ? SaturatedInteger{true, 0}
+                                      : SaturatedInteger{false, v};
+  }
+  int64_t asInteger() const { return saturated ? ShapedType::kDynamic : v; }
+  FailureOr<SaturatedInteger> desaturate(SaturatedInteger other) const {
+    if (saturated && !other.saturated) // Refine dynamic with static.
+      return other;
+    if (!saturated && !other.saturated && v != other.v) // Static conflict.
+      return failure();
+    return *this;
+  }
+  bool operator==(SaturatedInteger other) const {
+    return (saturated && other.saturated) ||
+           (!saturated && !other.saturated && v == other.v);
+  }
+  bool operator!=(SaturatedInteger other) const { return !(*this == other); }
+  SaturatedInteger operator+(SaturatedInteger other) const {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v + v};
+  }
+  SaturatedInteger operator*(SaturatedInteger other) const {
+    if (saturated || other.saturated)
+      return SaturatedInteger{true, 0};
+    return SaturatedInteger{false, other.v * v};
+  }
+  bool saturated = true; // True means dynamic; default-constructed is dynamic.
+  int64_t v = 0;         // Static value; meaningful only when !saturated.
+};
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a2fc954ad07fae8..dce96cca016ff8e 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -26,43 +26,6 @@
 using namespace mlir;
 using namespace mlir::memref;
 
-namespace {
-/// Idiomatic saturated operations on offsets, sizes and strides.
-namespace saturated_arith {
-struct Wrapper {
-  static Wrapper stride(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper offset(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  static Wrapper size(int64_t v) {
-    return (ShapedType::isDynamic(v)) ? Wrapper{true, 0} : Wrapper{false, v};
-  }
-  int64_t asOffset() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asSize() { return saturated ? ShapedType::kDynamic : v; }
-  int64_t asStride() { return saturated ? ShapedType::kDynamic : v; }
-  bool operator==(Wrapper other) {
-    return (saturated && other.saturated) ||
-           (!saturated && !other.saturated && v == other.v);
-  }
-  bool operator!=(Wrapper other) { return !(*this == other); }
-  Wrapper operator+(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v + v};
-  }
-  Wrapper operator*(Wrapper other) {
-    if (saturated || other.saturated)
-      return Wrapper{true, 0};
-    return Wrapper{false, other.v * v};
-  }
-  bool saturated;
-  int64_t v;
-};
-} // namespace saturated_arith
-} // namespace
-
 /// Materialize a single constant operation from a given attribute value with
 /// the desired resultant type.
 Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
@@ -2208,11 +2171,11 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
     ReassociationIndices reassoc = std::get<0>(it);
     int64_t currentStrideToExpand = std::get<1>(it);
     for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
-      using saturated_arith::Wrapper;
       reverseResultStrides.push_back(currentStrideToExpand);
-      currentStrideToExpand = (Wrapper::stride(currentStrideToExpand) *
-                               Wrapper::size(resultShape[shapeIndex--]))
-                                  .asStride();
+      currentStrideToExpand =
+          (SaturatedInteger::wrap(currentStrideToExpand) *
+           SaturatedInteger::wrap(resultShape[shapeIndex--]))
+              .asInteger();
     }
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
@@ -2332,10 +2295,9 @@ computeCollapsedLayoutMap(MemRefType srcType,
   unsigned resultStrideIndex = resultStrides.size() - 1;
   for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
     auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
-    using saturated_arith::Wrapper;
-    auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
+    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
     for (int64_t idx : llvm::reverse(trailingReassocs)) {
-      stride = stride * Wrapper::size(srcShape[idx]);
+      stride = stride * SaturatedInteger::wrap(srcShape[idx]);
 
       // Both source and result stride must have the same static value. In that
       // case, we can be sure, that the dimensions are collapsible (because they
@@ -2345,7 +2307,7 @@ computeCollapsedLayoutMap(MemRefType srcType,
       // ops where obviously non-contiguous dims are collapsed, but accept ops
       // where we cannot be sure statically. Such ops may fail at runtime. See
       // the op documentation for details.
-      auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
       if (strict && (stride.saturated || srcStride.saturated))
         return failure();
 
@@ -2371,11 +2333,11 @@ MemRefType CollapseShapeOp::computeCollapsedType(
   SmallVector<int64_t> resultShape;
   resultShape.reserve(reassociation.size());
   for (const ReassociationIndices &group : reassociation) {
-    using saturated_arith::Wrapper;
-    auto groupSize = Wrapper::size(1);
+    auto groupSize = SaturatedInteger::wrap(1);
     for (int64_t srcDim : group)
-      groupSize = groupSize * Wrapper::size(srcType.getDimSize(srcDim));
-    resultShape.push_back(groupSize.asSize());
+      groupSize =
+          groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
+    resultShape.push_back(groupSize.asInteger());
   }
 
   if (srcType.getLayout().isIdentity()) {
@@ -2586,11 +2548,10 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   int64_t targetOffset = sourceOffset;
   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetOffset =
-        (Wrapper::offset(targetOffset) +
-         Wrapper::offset(staticOffset) * Wrapper::stride(targetStride))
-            .asOffset();
+    targetOffset = (SaturatedInteger::wrap(targetOffset) +
+                    SaturatedInteger::wrap(staticOffset) *
+                        SaturatedInteger::wrap(targetStride))
+                       .asInteger();
   }
 
   // Compute target stride whose value is:
@@ -2599,10 +2560,9 @@ Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
   targetStrides.reserve(staticOffsets.size());
   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
-    using saturated_arith::Wrapper;
-    targetStrides.push_back(
-        (Wrapper::stride(sourceStride) * Wrapper::stride(staticStride))
-            .asStride());
+    targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
+                             SaturatedInteger::wrap(staticStride))
+                                .asInteger());
   }
 
   // The type is now known.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 3a777f32e09a317..27b54b043becf0d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -476,53 +476,32 @@ void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ConcatOp
 //===----------------------------------------------------------------------===//
 
-FailureOr<RankedTensorType> ConcatOp::inferResultType(int64_t dim,
-                                                      TypeRange inputTypes) {
-  if (dim < 0)
-    return failure();
-
-  if (inputTypes.empty())
-    return failure();
-
-  RankedTensorType init = dyn_cast<RankedTensorType>(inputTypes[0]);
-  if (!init)
-    return failure();
-
-  // The tensor rank must be greater than the concatenation dim.
-  int64_t concatRank = init.getRank();
-  if (concatRank <= dim)
-    return failure();
-
-  SmallVector<int64_t> sizes(init.getShape());
-  Type elementType = init.getElementType();
-  for (Type type : inputTypes.drop_front()) {
-    RankedTensorType tensorType = dyn_cast<RankedTensorType>(type);
-    if (!tensorType || tensorType.getRank() != concatRank ||
-        tensorType.getElementType() != elementType)
-      return failure();
-
-    for (auto [index, currSize] : llvm::enumerate(tensorType.getShape())) {
-      int64_t size = sizes[index];
-      bool hasDynamic =
-          ShapedType::isDynamic(size) || ShapedType::isDynamic(currSize);
-      if (static_cast<int64_t>(index) == dim) {
-        sizes[index] = hasDynamic ? ShapedType::kDynamic : currSize + size;
-        continue;
-      }
-
-      // If the sizes are statically different for a dimension other than the
-      // concated dimension, the concatenation is invalid. Both dynamic or
-      // mixed dynamic and static is fine.
-      if (currSize != size && !hasDynamic)
-        return failure();
-
-      // If the new size is not dynamic, use the additional static information.
-      if (!ShapedType::isDynamic(currSize))
-        sizes[index] = currSize;
-    }
+RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
+  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
+  auto tensorTypes =
+      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
+        return llvm::cast<RankedTensorType>(type);
+      }));
+  int64_t concatRank = tensorTypes[0].getRank();
+
+  // The concatenation dim must be in the range [0, rank).
+  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
+
+  SmallVector<int64_t> sizes((concatRank));
+  for (int64_t i = 0, e = concatRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : tensorTypes)
+      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+    sizes[i] = size.asInteger();
   }
-
-  return RankedTensorType::get(sizes, elementType);
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : tensorTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
 }
 
 void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
@@ -542,8 +521,7 @@ LogicalResult ConcatOp::verify() {
     inputTypes.push_back(cast<RankedTensorType>(input.getType()));
 
   RankedTensorType resultType = getResultType();
-
-  int64_t resultRank = resultType.getRank();
+  int64_t resultRank = getRank();
   if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
         return type.getRank() != resultRank;
       }))
@@ -555,22 +533,41 @@ LogicalResult ConcatOp::verify() {
       }))
     return emitOpError("inputs and result element type must match");
 
-  if (static_cast<int64_t>(getDim()) >= resultRank)
+  int64_t dim = getDim();
+  if (dim >= resultRank)
     return emitOpError("concatenation dim must be less than the tensor rank");
 
-  FailureOr<RankedTensorType> inferredResultType =
-      inferResultType(getDim(), getInputs().getTypes());
-  if (failed(inferredResultType))
-    return emitOpError("failed to infer concatenation result type from inputs");
+  SmallVector<int64_t> sizes((resultRank));
+  for (int64_t i = 0, e = resultRank; i < e; ++i) {
+    if (i == dim)
+      continue;
+    SaturatedInteger size;
+    for (auto tensorType : inputTypes) {
+      FailureOr<SaturatedInteger> maybeSize =
+          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
+      if (failed(maybeSize))
+        return emitOpError("static concatenation size mismatch along ")
+               << "non-concatenated dimension " << i;
+      size = *maybeSize;
+    }
+    sizes[i] = size.asInteger();
+  }
+  auto concatSize = SaturatedInteger::wrap(0);
+  for (auto tensorType : inputTypes)
+    concatSize =
+        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
+  sizes[dim] = concatSize.asInteger();
+  auto inferredResultType =
+      RankedTensorType::get(sizes, inputTypes[0].getElementType());
 
   for (auto [inferredSize, actualSize] :
-       llvm::zip_equal(inferredResultType->getShape(), resultType.getShape())) {
+       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
     bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                       ShapedType::isDynamic(actualSize);
     if (!hasDynamic && inferredSize != actualSize)
       return emitOpError("result type ")
              << resultType << "does not match inferred shape "
-             << *inferredResultType << " static sizes";
+             << inferredResultType << " static sizes";
   }
 
   return success();
@@ -581,11 +578,7 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
   ValueRange inputs = getInputs();
   int64_t dim = getDim();
-  FailureOr<RankedTensorType> maybeInferredResultType =
-      inferResultType(dim, inputs.getTypes());
-  if (failed(maybeInferredResultType))
-    return failure();
-  RankedTensorType inferredResultType = *maybeInferredResultType;
+  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
 
   Value init = inputs[0];
   int64_t rank = getType().getRank();
@@ -595,43 +588,40 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
   // Pre-populate the result sizes with as much static information as possible
   // from the given result type, as well as the inferred result type, otherwise
   // use the dim sizes from the first input.
-  bool hasStaticConcatDim = false;
   for (int64_t i = 0; i < rank; ++i) {
+    if (i == dim)
+      continue;
     if (!getType().isDynamicDim(i)) {
       reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
-      if (i == dim)
-        hasStaticConcatDim = true;
     } else if (!inferredResultType.isDynamicDim(i)) {
-      // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
-      // returns a Value for dynamic dimensions.
       reifiedReturnShapes[0][i] =
-          builder
-              .create<arith::ConstantIndexOp>(getLoc(),
-                                              inferredResultType.getDimSize(i))
-              .getResult();
-      if (i == dim)
-        hasStaticConcatDim = true;
+          builder.getIndexAttr(inferredResultType.getDimSize(i));
     } else {
       reifiedReturnShapes[0][i] =
           builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
     }
   }
 
-  // Check if we already know the size of the concatenation dim statically.
-  if (hasStaticConcatDim)
-    return success();
-
   // Take the sum of the input sizes along the concatenated dim.
-  Value concatValue = getValueOrCreateConstantIndexOp(
-      builder, getLoc(), reifiedReturnShapes[0][dim]);
-
-  for (Value input : inputs.drop_front()) {
-    Value newSize =
-        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim);
-    concatValue = builder.createOrFold<arith::AddIOp>(input.getLoc(),
-                                                      concatValue, newSize);
+  AffineExpr sum = builder.getAffineDimExpr(0);
+  SmallVector<OpFoldResult> sizes{
+      builder.create<tensor::DimOp>(init.getLoc(), init, 0).getResult()};
+  for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
+    sum = sum + builder.getAffineDimExpr(idx + 1);
+    sizes.push_back(
+        builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
+  }
+  reifiedReturnShapes[0][dim] =
+      affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes);
+
+  // ReifyRankedShapedTypeOpInterface requires that reifyResultShapes
+  // returns a Value for dynamic dimensions.
+  for (int64_t i = 0; i < rank; ++i) {
+    if (getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
+          builder, getLoc(), reifiedReturnShapes[0][i]);
+    }
   }
-  reifiedReturnShapes[0][dim] = concatValue;
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
index 3cec91389392246..ed274238704713c 100644
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -83,6 +83,11 @@ void tensor::registerFindPayloadReplacementOpInterfaceExternalModels(
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
 
+void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  tensor::populateDecomposeTensorConcatPatterns(patterns);
+}
+
 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
     populatePatterns(RewritePatternSet &patterns) {
   tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 6fa34e5dd2b3f3e..c5cb30c3ad13eb8 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -5,7 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -32,7 +33,6 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
 
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-
     Location loc = concatOp.getLoc();
     FailureOr<Value> dest =
         tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
@@ -50,21 +50,28 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
     int64_t rank = concatOp.getResultType().getRank();
     SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
     SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> dimOffsets;
-    dimOffsets.push_back(rewriter.getIndexAttr(0));
 
-    for (auto input : concatOp.getInputs().drop_front()) {
-      Value size = rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue);
-      Value currentOffset =
-          getValueOrCreateConstantIndexOp(rewriter, loc, dimOffsets.back());
-      Value total =
-          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
-      dimOffsets.push_back(getAsOpFoldResult(total));
+    // Compute the partial sums for the slice offsets.
+    AffineExpr sum = rewriter.getAffineDimExpr(0);
+    SmallVector<AffineExpr> partialSums{sum};
+    SmallVector<OpFoldResult> offsetStrides{rewriter.getIndexAttr(0)};
+    for (auto [idx, input] :
+         llvm::enumerate(concatOp.getInputs().drop_back())) {
+      sum = sum + rewriter.getAffineDimExpr(idx + 1);
+      partialSums.push_back(sum);
+      offsetStrides.push_back(
+          rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
     }
+    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
+                                        partialSums, rewriter.getContext());
+    SmallVector<OpFoldResult> dimOffsets =
+        affine::makeComposedFoldedMultiResultAffineApply(
+            rewriter, loc, partialSumMap, offsetStrides);
 
+    // Construct the chain of insert_slice ops into the destination.
     Value result = *dest;
-
-    for (auto [input, offset] : llvm::zip(concatOp.getInputs(), dimOffsets)) {
+    for (auto [input, offset] :
+         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
       SmallVector<OpFoldResult> sizes =
           tensor::getMixedSizes(rewriter, loc, input);
       offsets[dim] = offset;
diff --git a/mlir/test/Dialect/Tensor/decompose-concat.mlir b/mlir/test/Dialect/Tensor/decompose-concat.mlir
index dd7b9cd5f1490ae..5712c77a743d71b 100644
--- a/mlir/test/Dialect/Tensor/decompose-concat.mlir
+++ b/mlir/test/Dialect/Tensor/decompose-concat.mlir
@@ -1,27 +1,45 @@
-// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-decompose-concat -cse  %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -transform-interpreter -cse  %s | FileCheck %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_concat
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
 
 func.func @decompose_dynamic_concat(%arg0 : tensor<8x4xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<8x4xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
   return %0 : tensor<?x?xf32>
 }
+//   CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: func @decompose_dynamic_concat(
 //  CHECK-SAME:     %[[ARG0:.+]]: tensor<8x4xf32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
 
 //   CHECK-DAG:     %[[C8:.+]] = arith.constant 8 : index
-//   CHECK-DAG:     %[[C4:.+]] = arith.constant 4 : index
 //   CHECK-DAG:     %[[C1:.+]] = arith.constant 1 : index
 //   CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
 //       CHECK:     %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//       CHECK:     %[[CONCAT_SIZE:.+]] = arith.addi %[[DIM]], %[[C4]] : index
+//       CHECK:     %[[CONCAT_SIZE:.+]] = affine.apply #[[$MAP]]()[%[[C8]], %[[DIM]]]
 //       CHECK:     %[[EMPTY:.+]] = tensor.empty(%[[C8]], %[[CONCAT_SIZE]]) : tensor<?x?xf32>
 //       CHECK:     %[[SLICE0:.+]] = tensor.insert_slice %[[ARG0]] into %[[EMPTY]][0, 0] [8, 4] [1, 1] : tensor<8x4xf32> into tensor<?x?xf32>
 //       CHECK:     %[[OFFSET:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, %[[DIM]]] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+//       CHECK:     %[[CONCAT:.+]] = tensor.insert_slice %[[ARG1]] into %[[SLICE0]][0, 4] [%[[OFFSET]], %[[DIM]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 //       CHECK:     return %[[CONCAT]] : tensor<?x?xf32>
 
 // -----
 
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.tensor.decompose_concat
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
+
 func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
                             %arg1 : tensor<2xf32>,
                             %arg2 : tensor<3xf32>,
@@ -33,7 +51,7 @@ func.func @decompose_1d_concat(%arg0 : tensor<1xf32>,
 // CHECK-LABEL: func @decompose_1d_concat
 //       CHECK:    tensor.empty() : tensor<10xf32>
 //       CHECK:    tensor.insert_slice %{{.*}}[0] [1] [1] : tensor<1xf32> into tensor<10xf32>
-//       CHECK:    tensor.insert_slice %{{.*}}[2] [2] [1] : tensor<2xf32> into tensor<10xf32>
-//       CHECK:    tensor.insert_slice %{{.*}}[5] [3] [1] : tensor<3xf32> into tensor<10xf32>
-//       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[9] [4] [1] : tensor<4xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[1] [2] [1] : tensor<2xf32> into tensor<10xf32>
+//       CHECK:    tensor.insert_slice %{{.*}}[3] [3] [1] : tensor<3xf32> into tensor<10xf32>
+//       CHECK:    %[[CONCAT:.+]] = tensor.insert_slice %{{.*}}[6] [4] [1] : tensor<4xf32> into tensor<10xf32>
 //       CHECK:    return %[[CONCAT]] : tensor<10xf32>
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 426372f0ba4e746..9b6c2327879cf9b 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -49,7 +49,7 @@ func.func @concat_element_type_mismatch(%arg0: tensor<3xf32>, %arg1: tensor<3xi3
 // -----
 
 func.func @concat_incompatible_input_types(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) {
-  // expected-error at +1 {{failed to infer concatenation result type from inputs}}
+  // expected-error at +1 {{static concatenation size mismatch along non-concatenated dimension 1}}
   %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<7x5xf32>
   return
 }
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index bcddfd313ef0484..3e142155df8d9b8 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -77,10 +77,6 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
       llvm::cl::init(false)};
 
-  Option<bool> testDecomposeConcat{
-      *this, "test-decompose-concat",
-      llvm::cl::desc("Test decomposing tensor.concat"), llvm::cl::init(false)};
-
   Option<bool> useForeach{
       *this, "use-foreach",
       llvm::cl::desc(
@@ -112,12 +108,6 @@ static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
-static void applyDecomposeConcatPatterns(Operation *rootOp) {
-  RewritePatternSet patterns(rootOp->getContext());
-  tensor::populateDecomposeTensorConcatPatterns(patterns);
-  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
-}
-
 static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::ControlConstantExtractSliceFusionFn controlFn =
@@ -398,8 +388,6 @@ void TestTensorTransforms::runOnOperation() {
     applyReassociativeReshapeFoldingPatterns(rootOp);
   if (testFoldIntoPackAndUnpack)
     applyFoldIntoPackAndUnpackPatterns(rootOp);
-  if (testDecomposeConcat)
-    applyDecomposeConcatPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {
     if (failed(
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))

>From 2dc1694c71ba6e0e4c13b5f65c292f8c12f03e83 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Fri, 1 Dec 2023 14:36:57 -0500
Subject: [PATCH 4/5] Address comments

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp                | 6 +++---
 mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 27b54b043becf0d..02146e8257b38e3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -487,7 +487,7 @@ RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
   // The concatenation dim must be in the range [0, rank).
   assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
 
-  SmallVector<int64_t> sizes((concatRank));
+  SmallVector<int64_t> sizes(concatRank);
   for (int64_t i = 0, e = concatRank; i < e; ++i) {
     if (i == dim)
       continue;
@@ -537,7 +537,7 @@ LogicalResult ConcatOp::verify() {
   if (dim >= resultRank)
     return emitOpError("concatenation dim must be less than the tensor rank");
 
-  SmallVector<int64_t> sizes((resultRank));
+  SmallVector<int64_t> sizes(resultRank);
   for (int64_t i = 0, e = resultRank; i < e; ++i) {
     if (i == dim)
       continue;
@@ -604,7 +604,7 @@ ConcatOp::reifyResultShapes(OpBuilder &builder,
 
   // Take the sum of the input sizes along the concatenated dim.
   AffineExpr sum = builder.getAffineDimExpr(0);
-  SmallVector<OpFoldResult> sizes{
+  SmallVector<OpFoldResult> sizes = {
       builder.create<tensor::DimOp>(init.getLoc(), init, 0).getResult()};
   for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
     sum = sum + builder.getAffineDimExpr(idx + 1);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index c5cb30c3ad13eb8..2108fc591055a82 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -53,8 +53,8 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
 
     // Compute the partial sums for the slice offsets.
     AffineExpr sum = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> partialSums{sum};
-    SmallVector<OpFoldResult> offsetStrides{rewriter.getIndexAttr(0)};
+    SmallVector<AffineExpr> partialSums = {sum};
+    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
     for (auto [idx, input] :
          llvm::enumerate(concatOp.getInputs().drop_back())) {
       sum = sum + rewriter.getAffineDimExpr(idx + 1);



More information about the Mlir-commits mailing list