[Mlir-commits] [mlir] [mlir][tensor] Implement TilingInterface for tensor.concat (PR #177786)
Tomer Solomon
llvmlistbot at llvm.org
Sat Jan 24 10:55:25 PST 2026
https://github.com/recursion-man updated https://github.com/llvm/llvm-project/pull/177786
>From 842757fe0c9863538dff1a98dc61da9dbff9f773 Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomer.solomon at mobileye.com>
Date: Sat, 24 Jan 2026 20:02:24 +0200
Subject: [PATCH] [mlir][tensor] Implement TilingInterface for tensor.concat
- This patch implements TilingInterface for tensor.concat.
Motivation: Currently, tensor.concat acts as a hard barrier for tile-and-fuse transformations. To fuse producers through a concat, we are currently forced to lower the concat into tensor.insert_slice operations before tiling. However, lowering to insert_slice early can be undesirable because it breaks the tensor.concat abstraction. By implementing TilingInterface, we can tile the tensor.concat "in place." This creates a tiled tensor.concat operating on smaller slices, preserving the operation semantics while enabling fusion with producers and consumers.
- Implementation logic: the implementation supports tiling on all dimensions except the concatenation dimension itself. Non-concat dimensions are tiled normally: their tile offsets and sizes are propagated 1:1 to the input operands, while each input is kept whole along the concatenation dimension.
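Tiling along the concatenation dimension is not supported in this implementation: a tile along that axis can overlap a varying number of inputs, each contributing a partial slice whose bounds depend on the tile's position, which does not map cleanly to a fixed set of tensor.extract_slice operations.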
---
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 124 +++++++++++++
.../tile-concat-using-interface.mlir | 175 ++++++++++++++++++
2 files changed, 299 insertions(+)
create mode 100644 mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 124a63281a37c..36909f9d8288d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -13,8 +13,10 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::tensor;
@@ -83,6 +85,127 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConcatOpTiling
+//===----------------------------------------------------------------------===//
+//
+// Tiling implementation for tensor.concat.
+//
+// The iteration domain is the result shape and every loop is parallel. Only
+// the non-concatenated dimensions can be tiled: a tile along the concatenated
+// dimension may overlap a varying number of inputs, each contributing a
+// partial slice whose bounds depend on where the tile lands, which cannot be
+// expressed as a fixed set of tensor.extract_slice ops. A tile is therefore
+// accepted only if it covers the concatenated dimension entirely.
+//
+//===----------------------------------------------------------------------===//
+
+struct ConcatOpTiling
+    : public TilingInterface::ExternalModel<ConcatOpTiling, ConcatOp> {
+
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto concatOp = cast<ConcatOp>(op);
+    return SmallVector<utils::IteratorType>(concatOp.getRank(),
+                                            utils::IteratorType::parallel);
+  }
+
+  /// One loop per result dimension, ranging over [0, dim size) with step 1.
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    ReifiedRankedShapedTypeDims reifiedShapes;
+    LogicalResult reified = reifyResultShapes(b, op, reifiedShapes);
+    // Indexing reifiedShapes[0] below is only valid if reification succeeded.
+    assert(succeeded(reified) && "failed to reify tensor.concat result shape");
+    (void)reified;
+
+    OpFoldResult zero = b.getIndexAttr(0);
+    OpFoldResult one = b.getIndexAttr(1);
+    SmallVector<Range> loopBounds;
+    loopBounds.reserve(reifiedShapes[0].size());
+    for (OpFoldResult size : reifiedShapes[0])
+      loopBounds.push_back(Range{zero, size, one});
+    return loopBounds;
+  }
+
+  /// Returns the full extent of the concatenated dimension as a value-bounds
+  /// variable without creating any IR: the static result size when it is
+  /// known, otherwise the sum of the input sizes along that dimension.
+  static ValueBoundsConstraintSet::Variable
+  getConcatDimExtent(OpBuilder &b, ConcatOp concatOp) {
+    int64_t concatDim = concatOp.getDim();
+    auto resultType = cast<RankedTensorType>(concatOp.getResult().getType());
+    if (!resultType.isDynamicDim(concatDim))
+      return ValueBoundsConstraintSet::Variable(
+          OpFoldResult(b.getIndexAttr(resultType.getDimSize(concatDim))));
+
+    AffineExpr sum = b.getAffineConstantExpr(0);
+    SmallVector<ValueBoundsConstraintSet::Variable> inputExtents;
+    for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {
+      sum = sum + b.getAffineDimExpr(static_cast<unsigned>(idx));
+      inputExtents.emplace_back(input, concatDim);
+    }
+    return ValueBoundsConstraintSet::Variable(
+        AffineMap::get(inputExtents.size(), /*symbolCount=*/0, sum),
+        inputExtents);
+  }
+
+  /// Slices every input with the tile's offsets/sizes on the non-concatenated
+  /// dimensions and concatenates the slices. Fails (without creating IR) if
+  /// the tile does not cover the whole concatenated dimension.
+  FailureOr<TilingResult>
+  getTiledImplementation(Operation *op, OpBuilder &b,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes) const {
+    auto concatOp = cast<ConcatOp>(op);
+    Location loc = concatOp.getLoc();
+    int64_t concatDim = concatOp.getDim();
+    int64_t rank = concatOp.getRank();
+
+    // Reject tiles that do not cover the whole concatenated dimension. Both
+    // checks run before any IR is built, so a rejected tile leaves no dead ops
+    // behind -- in particular no tensor.dim of this op's own result, which
+    // would refer to the op that tiling is about to replace.
+    if (!isZeroInteger(offsets[concatDim]))
+      return failure();
+    FailureOr<bool> coversConcatDim = ValueBoundsConstraintSet::areEqual(
+        sizes[concatDim], getConcatDimExtent(b, concatOp));
+    if (failed(coversConcatDim) || !*coversConcatDim)
+      return failure();
+
+    OpFoldResult zero = b.getIndexAttr(0);
+    OpFoldResult one = b.getIndexAttr(1);
+    SmallVector<OpFoldResult> strides(rank, one);
+
+    SmallVector<Operation *> generatedSlices;
+    SmallVector<Value> tiledInputs;
+    generatedSlices.reserve(concatOp.getInputs().size());
+    tiledInputs.reserve(concatOp.getInputs().size());
+
+    // Each input gets the tile's offsets/sizes on the non-concatenated
+    // dimensions and is kept whole along the concatenated one.
+    for (Value input : concatOp.getInputs()) {
+      SmallVector<OpFoldResult> inputOffsets(offsets.begin(), offsets.end());
+      SmallVector<OpFoldResult> inputSizes(sizes.begin(), sizes.end());
+      inputOffsets[concatDim] = zero;
+      inputSizes[concatDim] = tensor::getMixedSize(b, loc, input, concatDim);
+
+      auto slice = tensor::ExtractSliceOp::create(b, loc, input, inputOffsets,
+                                                  inputSizes, strides);
+      generatedSlices.push_back(slice);
+      tiledInputs.push_back(slice.getResult());
+    }
+
+    auto tiledConcat = ConcatOp::create(b, loc, concatDim, tiledInputs);
+    return TilingResult{{tiledConcat.getOperation()},
+                        {tiledConcat.getResult()},
+                        generatedSlices};
+  }
+
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    // The iteration domain is the result shape, so an iteration-domain tile
+    // is also the result tile.
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    // Inverse of getResultTilePosition: the mapping is the identity.
+    iterDomainOffsets.assign(offsets.begin(), offsets.end());
+    iterDomainSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+
+  /// Used when fusing the concat as a producer into a consumer's loop nest.
+  /// Result tiles map 1:1 to iteration-domain tiles, so this is the same as
+  /// tiling the op; it fails under the same conditions.
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    return getTiledImplementation(op, b, offsets, sizes);
+  }
+};
+
} // namespace
FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
@@ -312,5 +435,6 @@ void mlir::tensor::registerTilingInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+    tensor::ConcatOp::attachInterface<ConcatOpTiling>(*ctx);
   });
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir
new file mode 100644
index 0000000000000..b2be4d412932d
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir
@@ -0,0 +1,175 @@
+// RUN: mlir-opt -transform-interpreter -cse -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Concat along dim 1; tile only dim 0 (a non-concatenated dimension) by 2.
+
+func.func @concat_tile_dim0(%lhs: tensor<4x8xf32>, %rhs: tensor<4x8xf32>) -> tensor<4x16xf32> {
+  %result = tensor.concat dim(1) %lhs, %rhs : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x16xf32>
+  return %result : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["tensor.concat"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    // Tile size 0 leaves the concatenated dimension (dim 1) untiled.
+    %tiled_op, %for_op = transform.structured.tile_using_for %op tile_sizes [2, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func.func @concat_tile_dim0(
+// CHECK: scf.for
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.concat dim(1)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Concat along dim 1 with a dynamic dim 0; tiling dim 0 by 2 has to clamp
+// the trailing tile, hence the affine.min inside the loop.
+
+func.func @concat_tile_dynamic(%lhs: tensor<?x8xf32>, %rhs: tensor<?x8xf32>) -> tensor<?x16xf32> {
+  %result = tensor.concat dim(1) %lhs, %rhs : (tensor<?x8xf32>, tensor<?x8xf32>) -> tensor<?x16xf32>
+  return %result : tensor<?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["tensor.concat"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %tiled_op, %for_op = transform.structured.tile_using_for %op tile_sizes [2, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func.func @concat_tile_dynamic(
+// CHECK: tensor.dim
+// CHECK: scf.for
+// CHECK: affine.min
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.concat dim(1)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Concat along dim 0, tile only dim 1 by 4: every input slice must start at
+// offset 0 on the concatenated dimension.
+
+func.func @concat_on_dim0_tile_dim1(%lhs: tensor<4x8xf32>, %rhs: tensor<4x8xf32>) -> tensor<8x8xf32> {
+  %result = tensor.concat dim(0) %lhs, %rhs : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<8x8xf32>
+  return %result : tensor<8x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["tensor.concat"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    // Tile size 0 leaves the concatenated dimension (dim 0) untiled.
+    %tiled_op, %for_op = transform.structured.tile_using_for %op tile_sizes [0, 4]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func.func @concat_on_dim0_tile_dim1(
+// CHECK: scf.for
+// CHECK: tensor.extract_slice {{.*}}[0,
+// CHECK: tensor.extract_slice {{.*}}[0,
+// CHECK: tensor.concat dim(0)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Three-way concat along dim 1: tiling dim 0 produces one slice per input.
+
+func.func @concat_three_inputs(%a: tensor<4x8xf32>, %b: tensor<4x8xf32>, %c: tensor<4x8xf32>) -> tensor<4x24xf32> {
+  %result = tensor.concat dim(1) %a, %b, %c : (tensor<4x8xf32>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x24xf32>
+  return %result : tensor<4x24xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["tensor.concat"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %tiled_op, %for_op = transform.structured.tile_using_for %op tile_sizes [2, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func.func @concat_three_inputs(
+// CHECK: scf.for
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.concat dim(1)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Rank-3 concat along dim 2; tile dims 0 and 1 (both non-concatenated) with a
+// two-level loop nest, slicing each input as 2x4 over its full 16 on dim 2.
+
+func.func @concat_3d_tile_two_dims(%lhs: tensor<4x8x16xf32>, %rhs: tensor<4x8x16xf32>) -> tensor<4x8x32xf32> {
+  %result = tensor.concat dim(2) %lhs, %rhs : (tensor<4x8x16xf32>, tensor<4x8x16xf32>) -> tensor<4x8x32xf32>
+  return %result : tensor<4x8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["tensor.concat"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    // Tile size 0 leaves the concatenated dimension (dim 2) untiled.
+    %tiled_op, %outer_for, %inner_for = transform.structured.tile_using_for %op tile_sizes [2, 4, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func.func @concat_3d_tile_two_dims(
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: tensor.extract_slice {{.*}} [2, 4, 16]
+// CHECK: tensor.extract_slice {{.*}} [2, 4, 16]
+// CHECK: tensor.concat dim(2)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Negative test: a non-zero tile size on the concatenated dimension (dim 1)
+// must be rejected by the tiling implementation.
+
+func.func @concat_tile_concat_dim_fail(%lhs: tensor<4x8xf32>, %rhs: tensor<4x8xf32>) -> tensor<4x16xf32> {
+  // expected-error @below {{faild to tile operation}}
+  // expected-error @below {{failed to generate tiling loops}}
+  %result = tensor.concat dim(1) %lhs, %rhs : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x16xf32>
+  return %result : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["tensor.concat"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    // Tile size 4 on dim 1, the concatenated dimension.
+    %tiled_op, %for_op = transform.structured.tile_using_for %op tile_sizes [0, 4]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// Negative test: tiling a dynamically sized concatenated dimension (dim 0)
+// must be rejected as well.
+
+func.func @concat_tile_dynamic_concat_dim_fail(%lhs: tensor<?x8xf32>, %rhs: tensor<?x8xf32>) -> tensor<?x8xf32> {
+  // expected-error @below {{faild to tile operation}}
+  // expected-error @below {{failed to generate tiling loops}}
+  %result = tensor.concat dim(0) %lhs, %rhs : (tensor<?x8xf32>, tensor<?x8xf32>) -> tensor<?x8xf32>
+  return %result : tensor<?x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops{["tensor.concat"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    // Tile size 4 on dim 0, the (dynamic) concatenated dimension.
+    %tiled_op, %for_op = transform.structured.tile_using_for %op tile_sizes [4, 0]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
More information about the Mlir-commits
mailing list