[Mlir-commits] [mlir] [mlir][tensor] Fix bug in `ConcatOpInterface` (PR #168676)
llvmlistbot at llvm.org
Tue Nov 18 22:49:02 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Longsheng Mou (CoTinker)
Changes:
This PR fixes a bug in `ConcatOpInterface` where bufferizing `tensor.concat` fails when the inputs' concat dimension is dynamic while the result type is static. The fix unifies the offset and size computation on `OpFoldResult`, so dynamic and static dimension values no longer need separate handling. Fixes #162776.
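For reference, the failing pattern is the one the new regression test below exercises: inputs whose concat dimension is dynamic feeding a fully static result type. A minimal reproducer, adapted from that test (function and value names are illustrative):

```mlir
// Concat along dim 1. %f and %g have a dynamic concat dimension, but
// the result type is static (8x10). The old lowering classified the op
// as "static" from the result type alone, so it read the kDynamic
// sentinel as an operand size and produced invalid subview sizes and
// offsets.
func.func @concat_mixed(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>,
                        %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
  %0 = tensor.concat dim(1) %f, %g, %h
      : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
  return %0 : tensor<8x10xf32>
}
```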
---
Full diff: https://github.com/llvm/llvm-project/pull/168676.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-41)
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+33-7)
``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c607ece418dff..5482cedae71d7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1132,35 +1132,23 @@ struct ConcatOpInterface
// Extract the dimension for the concat op
uint64_t concatDim = concatOp.getDim();
- bool dynamicConcatDim = false;
SmallVector<OpFoldResult> offsets(tensorType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> strides(tensorType.getRank(),
rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> sizes;
-
- for (const auto &[dimIdx, dimSize] :
- llvm::enumerate(tensorType.getShape())) {
- if (dimSize == ShapedType::kDynamic) {
- auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
- sizes.push_back(dimOp.getResult());
- if (dimIdx == concatDim)
- dynamicConcatDim = true;
- } else {
- sizes.push_back(rewriter.getIndexAttr(dimSize));
- }
- }
-
- int64_t concatDimOffset = 0;
- std::optional<Value> dynamicOffset;
- std::optional<Value> dynamicSize;
- if (dynamicConcatDim) {
- // One or more operands have dynamic size, so we must accumulate the
- // offset with arith ops.
- dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
- }
+ SmallVector<OpFoldResult> sizes =
+ memref::getMixedSizes(rewriter, loc, dstBuffer);
+
+ AffineExpr d0, d1;
+ bindDims(rewriter.getContext(), d0, d1);
+ // Add two integers.
+ auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, d0 + d1,
+ {v1, v2});
+ };
+ OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
for (auto operand : concatOp.getInputs()) {
// Get the buffer for the operand.
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1159,10 @@ struct ConcatOpInterface
// so the offset on that axis must accumulate through the loop, and the
// size must change to the size of the current operand.
auto operandTensorType = cast<RankedTensorType>(operand.getType());
- int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
- if (dynamicConcatDim) {
- offsets[concatDim] = dynamicOffset.value();
- dynamicSize =
- memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
- .getResult();
- sizes[concatDim] = dynamicSize.value();
- } else {
- sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
- offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
- }
+ offsets[concatDim] = concatDimOffset;
+ OpFoldResult concatDimSize =
+ memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
+ sizes[concatDim] = concatDimSize;
// Create a subview of the destination buffer.
auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1177,7 @@ struct ConcatOpInterface
if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
return failure();
- if (dynamicConcatDim) {
- dynamicOffset = arith::AddIOp::create(
- rewriter, loc, dynamicOffset.value(), dynamicSize.value());
- } else {
- concatDimOffset += operandConcatDimSize;
- }
+ concatDimOffset = sum(concatDimOffset, concatDimSize);
}
replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 5eb2360a29b8f..be8ce20d8f154 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
// CHECK-DAG: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<8x?xf32>
-// CHECK-DAG: %[[OFFSET:.*]] = arith.constant 0 : index
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
// CHECK: %[[ALLOC:.*]] = memref.alloc
// CHECK-SAME: memref<?x?xf32>
// CHECK-DAG: %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
-// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK: %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
-// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// -----
+// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>)
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
+// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
+// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
+ %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
+ return %0 : tensor<8x10xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @tensor.splat_dynamic(
// CHECK-SAME: %[[F:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[M:[a-zA-Z0-9_]+]]: index
``````````
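For the mixed test above, the bufferized IR that the new CHECK lines spell out looks roughly like the sketch below (hand-expanded from the CHECK lines; value names and subview result types are approximate). Note how the folded affine apply keeps the IR clean: `0 + %f_dim` folds to `%f_dim` directly, so an `affine.apply` is only materialized once two genuinely dynamic sizes have to be added.

```mlir
#sum = affine_map<()[s0, s1] -> (s0 + s1)>

func.func @concat_mixed(%f: memref<8x?xf32>, %g: memref<8x?xf32>,
                        %h: memref<8x2xf32>) -> memref<8x10xf32> {
  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
  %c1 = arith.constant 1 : index
  // First input: the running offset is the index attribute 0, folded
  // straight into the subview.
  %f_dim = memref.dim %f, %c1 : memref<8x?xf32>
  %sv1 = memref.subview %alloc[0, 0] [8, %f_dim] [1, 1]
      : memref<8x10xf32> to memref<8x?xf32, strided<[10, 1]>>
  memref.copy %f, %sv1 : memref<8x?xf32> to memref<8x?xf32, strided<[10, 1]>>
  // Second input: 0 + %f_dim folds to %f_dim, so no affine.apply yet.
  %g_dim = memref.dim %g, %c1 : memref<8x?xf32>
  %sv2 = memref.subview %alloc[0, %f_dim] [8, %g_dim] [1, 1]
      : memref<8x10xf32> to memref<8x?xf32, strided<[10, 1], offset: ?>>
  memref.copy %g, %sv2
      : memref<8x?xf32> to memref<8x?xf32, strided<[10, 1], offset: ?>>
  // Third input: two dynamic values must be added, so the running
  // offset materializes as an affine.apply.
  %off = affine.apply #sum()[%f_dim, %g_dim]
  %sv3 = memref.subview %alloc[0, %off] [8, 2] [1, 1]
      : memref<8x10xf32> to memref<8x2xf32, strided<[10, 1], offset: ?>>
  memref.copy %h, %sv3
      : memref<8x2xf32> to memref<8x2xf32, strided<[10, 1], offset: ?>>
  return %alloc : memref<8x10xf32>
}
```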
https://github.com/llvm/llvm-project/pull/168676
More information about the Mlir-commits mailing list