[Mlir-commits] [mlir] [mlir][tensor] Implement TilingInterface for tensor.concat (PR #177786)
Tomer Solomon
llvmlistbot at llvm.org
Sat Jan 24 10:48:52 PST 2026
https://github.com/recursion-man updated https://github.com/llvm/llvm-project/pull/177786
>From f99f3a519410ff796d1c01d9527b0d2d5702eb1b Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomer.solomon at mobileye.com>
Date: Sat, 24 Jan 2026 20:02:24 +0200
Subject: [PATCH] [mlir][tensor] Implement TilingInterface for tensor.concat
- This patch implements TilingInterface for tensor.concat.
Motivation: Currently, tensor.concat acts as a hard barrier for tile-and-fuse transformations. To fuse producers through a concat, we are currently forced to lower the concat into tensor.insert_slice operations before tiling. However, lowering to insert_slice early can be undesirable because it breaks the tensor.concat abstraction. By implementing TilingInterface, we can tile the tensor.concat "in place." This creates a tiled tensor.concat operating on smaller slices, preserving the operation semantics while enabling fusion with producers and consumers.
- Implementation logic: The implementation supports tiling on all dimensions except the concatenation dimension itself. Non-concat dimensions are tiled normally, and their offsets and sizes are propagated 1:1 to the input operands; each input keeps its full extent along the concat dimension.
Tiling along the concat dimension is not supported in this implementation: a tile along that axis could cover only part of some inputs and none of others, which does not map cleanly to one tensor.extract_slice per input.
---
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 133 ++++++++++++-
.../tile-concat-using-interface.mlir | 175 ++++++++++++++++++
2 files changed, 307 insertions(+), 1 deletion(-)
create mode 100644 mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 124a63281a37c..85aa38eb700be 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -13,8 +13,10 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::tensor;
@@ -74,12 +76,149 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
     iterDomainSizes.assign(sizes.begin(), sizes.end());
     return success();
   }
 
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes) const {
     return getTiledImplementation(op, b, offsets, sizes);
   }
 };
+
+//===----------------------------------------------------------------------===//
+// ConcatOpTiling
+//===----------------------------------------------------------------------===//
+//
+// Tiling implementation for tensor.concat.
+//
+// Every dimension except the concatenated one can be tiled: along those
+// dimensions the inputs have the same extent as the result, so a result tile
+// maps 1:1 onto a tile of each input. A tile covering only part of the
+// concatenated dimension would intersect just some of the inputs (and those
+// only partially), which cannot be expressed as one tensor.extract_slice per
+// input, so such tiles are rejected.
+//
+//===----------------------------------------------------------------------===//
+
+struct ConcatOpTiling
+    : public TilingInterface::ExternalModel<ConcatOpTiling, ConcatOp> {
+
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    auto concatOp = cast<ConcatOp>(op);
+    return SmallVector<utils::IteratorType>(concatOp.getRank(),
+                                            utils::IteratorType::parallel);
+  }
+
+  /// The iteration domain is the reified result shape, with offset 0 and
+  /// stride 1 along every dimension.
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    ReifiedRankedShapedTypeDims reifiedShapes;
+    LogicalResult reified = reifyResultShapes(b, op, reifiedShapes);
+    (void)reified;
+    assert(succeeded(reified) && "failed to reify tensor.concat result shape");
+    OpFoldResult zero = b.getIndexAttr(0);
+    OpFoldResult one = b.getIndexAttr(1);
+
+    SmallVector<Range> loopBounds;
+    loopBounds.reserve(reifiedShapes[0].size());
+    for (OpFoldResult size : reifiedShapes[0])
+      loopBounds.push_back(Range{zero, size, one});
+    return loopBounds;
+  }
+
+  /// Slices every input with the requested tile on the non-concatenated
+  /// dimensions and concatenates the slices. Fails unless the tile covers the
+  /// whole concatenated dimension.
+  FailureOr<TilingResult>
+  getTiledImplementation(Operation *op, OpBuilder &b,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes) const {
+    auto concatOp = cast<ConcatOp>(op);
+    Location loc = concatOp.getLoc();
+    int64_t concatDim = concatOp.getDim();
+    OperandRange inputs = concatOp.getInputs();
+
+    // Along the concatenated dimension the tile must start at 0...
+    if (!isZeroInteger(offsets[concatDim]))
+      return failure();
+
+    // ...and span the full extent, i.e. the sum of the input sizes along
+    // `concatDim`. The extent is described symbolically in terms of the
+    // inputs, so no IR is created before tiling is known to succeed.
+    SmallVector<ValueBoundsConstraintSet::Variable> inputExtents;
+    inputExtents.reserve(inputs.size());
+    AffineExpr sumExpr = b.getAffineConstantExpr(0);
+    for (auto [idx, input] : llvm::enumerate(inputs)) {
+      inputExtents.emplace_back(input, concatDim);
+      sumExpr = sumExpr + b.getAffineDimExpr(idx);
+    }
+    ValueBoundsConstraintSet::Variable fullExtent(
+        AffineMap::get(inputs.size(), /*symbolCount=*/0, sumExpr),
+        inputExtents);
+    FailureOr<bool> coversFullExtent =
+        ValueBoundsConstraintSet::areEqual(sizes[concatDim], fullExtent);
+    if (failed(coversFullExtent) || !*coversFullExtent)
+      return failure();
+
+    OpFoldResult zero = b.getIndexAttr(0);
+    OpFoldResult one = b.getIndexAttr(1);
+    SmallVector<OpFoldResult> strides(concatOp.getRank(), one);
+    SmallVector<Operation *> generatedSlices;
+    SmallVector<Value> tiledInputs;
+    generatedSlices.reserve(inputs.size());
+    tiledInputs.reserve(inputs.size());
+
+    // Slice each input with the tile offsets/sizes on the non-concatenated
+    // dimensions while keeping its full extent along `concatDim`.
+    for (Value input : inputs) {
+      SmallVector<OpFoldResult> inputOffsets(offsets.begin(), offsets.end());
+      SmallVector<OpFoldResult> inputSizes(sizes.begin(), sizes.end());
+      inputOffsets[concatDim] = zero;
+      inputSizes[concatDim] = tensor::getMixedSize(b, loc, input, concatDim);
+
+      auto extractSlice = tensor::ExtractSliceOp::create(
+          b, loc, input, inputOffsets, inputSizes, strides);
+      generatedSlices.push_back(extractSlice);
+      tiledInputs.push_back(extractSlice.getResult());
+    }
+
+    auto tiledConcat = ConcatOp::create(b, loc, concatDim, tiledInputs);
+    return TilingResult{{tiledConcat.getOperation()},
+                        {tiledConcat.getResult()},
+                        generatedSlices};
+  }
+
+  /// The iteration domain is the result shape, so a tile of the iteration
+  /// domain is also the position of the result tile.
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    resultOffsets.assign(offsets.begin(), offsets.end());
+    resultSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+
+  /// Inverse of getResultTilePosition; the mapping is the identity.
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    iterDomainOffsets.assign(offsets.begin(), offsets.end());
+    iterDomainSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+
+  /// A result tile is an iteration-domain tile (identity mapping), so the
+  /// tiled implementation produces it directly.
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    return getTiledImplementation(op, b, offsets, sizes);
+  }
+};
 
@@ -312,5 +451,6 @@ void mlir::tensor::registerTilingInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+    tensor::ConcatOp::attachInterface<ConcatOpTiling>(*ctx);
   });
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir
new file mode 100644
index 0000000000000..b2be4d412932d
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-concat-using-interface.mlir
@@ -0,0 +1,175 @@
+// RUN: mlir-opt -transform-interpreter -cse -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Test tiling a non-concat dimension (dim 0) of a concat along dim 1.
+
+func.func @concat_tile_dim0(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x16xf32> {
+ %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x16xf32>
+ return %0 : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %concat = transform.structured.match ops{["tensor.concat"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Tile dim 0 (non-concat) by 2; tile size 0 leaves dim 1 (concat) untiled.
+ %tiled, %loop = transform.structured.tile_using_for %concat tile_sizes [2, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func.func @concat_tile_dim0(
+// CHECK: scf.for
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.concat dim(1)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Test tiling a dynamic non-concat dimension (dim 0); the concat dim stays whole.
+
+func.func @concat_tile_dynamic(%arg0: tensor<?x8xf32>, %arg1: tensor<?x8xf32>) -> tensor<?x16xf32> {
+ %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<?x8xf32>, tensor<?x8xf32>) -> tensor<?x16xf32>
+ return %0 : tensor<?x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %concat = transform.structured.match ops{["tensor.concat"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %tiled, %loop = transform.structured.tile_using_for %concat tile_sizes [2, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func.func @concat_tile_dynamic(
+// CHECK: tensor.dim
+// CHECK: scf.for
+// CHECK: affine.min
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.concat dim(1)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Test a concat along dim 0 tiled on dim 1: every input slice starts at 0 on dim 0.
+
+func.func @concat_on_dim0_tile_dim1(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<8x8xf32> {
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<8x8xf32>
+ return %0 : tensor<8x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %concat = transform.structured.match ops{["tensor.concat"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Tile only dimension 1 (non-concat); dimension 0 (concat) stays untiled.
+ %tiled, %loop = transform.structured.tile_using_for %concat tile_sizes [0, 4]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func.func @concat_on_dim0_tile_dim1(
+// CHECK: scf.for
+// CHECK: tensor.extract_slice {{.*}}[0,
+// CHECK: tensor.extract_slice {{.*}}[0,
+// CHECK: tensor.concat dim(0)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Test tiling a concat with three inputs: one extract_slice is created per input.
+
+func.func @concat_three_inputs(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>, %arg2: tensor<4x8xf32>) -> tensor<4x24xf32> {
+ %0 = tensor.concat dim(1) %arg0, %arg1, %arg2 : (tensor<4x8xf32>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x24xf32>
+ return %0 : tensor<4x24xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %concat = transform.structured.match ops{["tensor.concat"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %tiled, %loop = transform.structured.tile_using_for %concat tile_sizes [2, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func.func @concat_three_inputs(
+// CHECK: scf.for
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: tensor.concat dim(1)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Test tiling two non-concat dimensions (dims 0 and 1) of a 3-D concat along dim 2.
+
+func.func @concat_3d_tile_two_dims(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x32xf32> {
+ %0 = tensor.concat dim(2) %arg0, %arg1 : (tensor<4x8x16xf32>, tensor<4x8x16xf32>) -> tensor<4x8x32xf32>
+ return %0 : tensor<4x8x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %concat = transform.structured.match ops{["tensor.concat"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Tile dimensions 0 and 1 (non-concat); dimension 2 (concat) stays untiled.
+ %tiled, %loop0, %loop1 = transform.structured.tile_using_for %concat tile_sizes [2, 4, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-LABEL: func.func @concat_3d_tile_two_dims(
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: tensor.extract_slice {{.*}} [2, 4, 16]
+// CHECK: tensor.extract_slice {{.*}} [2, 4, 16]
+// CHECK: tensor.concat dim(2)
+// CHECK: tensor.insert_slice
+
+// -----
+
+// Negative test: tiling the concat dimension (dim 1) is rejected, so tiling fails.
+
+func.func @concat_tile_concat_dim_fail(%arg0: tensor<4x8xf32>, %arg1: tensor<4x8xf32>) -> tensor<4x16xf32> {
+ // expected-error @below {{faild to tile operation}}
+ // expected-error @below {{failed to generate tiling loops}}
+ %0 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x16xf32>
+ return %0 : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %concat = transform.structured.match ops{["tensor.concat"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Attempt to tile the concat dimension (dim 1) by 4; this must fail.
+ %tiled, %loop = transform.structured.tile_using_for %concat tile_sizes [0, 4]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// Negative test: tiling a dynamic concat dimension (dim 0) is rejected as well.
+
+func.func @concat_tile_dynamic_concat_dim_fail(%arg0: tensor<?x8xf32>, %arg1: tensor<?x8xf32>) -> tensor<?x8xf32> {
+ // expected-error @below {{faild to tile operation}}
+ // expected-error @below {{failed to generate tiling loops}}
+ %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<?x8xf32>, tensor<?x8xf32>) -> tensor<?x8xf32>
+ return %0 : tensor<?x8xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %concat = transform.structured.match ops{["tensor.concat"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ // Attempt to tile the dynamic concat dimension (dim 0) by 4; this must fail.
+ %tiled, %loop = transform.structured.tile_using_for %concat tile_sizes [4, 0]
+ : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list