[Mlir-commits] [mlir] [mlir][tensor] Fix bug in `ConcatOpInterface` (PR #168676)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 18 22:49:02 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

This PR fixes an issue in `ConcatOpInterface` where bufferizing `tensor.concat` fails when an input operand's size along the concat dimension is dynamic while the result type is static. The fix unifies the offset and size computation by using `OpFoldResult`, so static and dynamic dimension values no longer need to be handled separately. Fixes #<!-- -->162776.

---
Full diff: https://github.com/llvm/llvm-project/pull/168676.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+16-41) 
- (modified) mlir/test/Dialect/Tensor/bufferize.mlir (+33-7) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c607ece418dff..5482cedae71d7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -1132,35 +1132,23 @@ struct ConcatOpInterface
 
     // Extract the dimension for the concat op
     uint64_t concatDim = concatOp.getDim();
-    bool dynamicConcatDim = false;
 
     SmallVector<OpFoldResult> offsets(tensorType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(tensorType.getRank(),
                                       rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> sizes;
-
-    for (const auto &[dimIdx, dimSize] :
-         llvm::enumerate(tensorType.getShape())) {
-      if (dimSize == ShapedType::kDynamic) {
-        auto dimOp = memref::DimOp::create(rewriter, loc, dstBuffer, dimIdx);
-        sizes.push_back(dimOp.getResult());
-        if (dimIdx == concatDim)
-          dynamicConcatDim = true;
-      } else {
-        sizes.push_back(rewriter.getIndexAttr(dimSize));
-      }
-    }
-
-    int64_t concatDimOffset = 0;
-    std::optional<Value> dynamicOffset;
-    std::optional<Value> dynamicSize;
-    if (dynamicConcatDim) {
-      // One or more operands have dynamic size, so we must accumulate the
-      // offset with arith ops.
-      dynamicOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    }
+    SmallVector<OpFoldResult> sizes =
+        memref::getMixedSizes(rewriter, loc, dstBuffer);
+
+    AffineExpr d0, d1;
+    bindDims(rewriter.getContext(), d0, d1);
+    // Add two integers.
+    auto sum = [&](OpFoldResult v1, OpFoldResult v2) {
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, d0 + d1,
+                                                   {v1, v2});
+    };
 
+    OpFoldResult concatDimOffset = rewriter.getIndexAttr(0);
     for (auto operand : concatOp.getInputs()) {
       // Get the buffer for the operand.
       FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
@@ -1171,18 +1159,10 @@ struct ConcatOpInterface
       // so the offset on that axis must accumulate through the loop, and the
       // size must change to the size of the current operand.
       auto operandTensorType = cast<RankedTensorType>(operand.getType());
-      int64_t operandConcatDimSize = operandTensorType.getDimSize(concatDim);
-
-      if (dynamicConcatDim) {
-        offsets[concatDim] = dynamicOffset.value();
-        dynamicSize =
-            memref::DimOp::create(rewriter, loc, *srcBuffer, concatDim)
-                .getResult();
-        sizes[concatDim] = dynamicSize.value();
-      } else {
-        sizes[concatDim] = rewriter.getIndexAttr(operandConcatDimSize);
-        offsets[concatDim] = rewriter.getIndexAttr(concatDimOffset);
-      }
+      offsets[concatDim] = concatDimOffset;
+      OpFoldResult concatDimSize =
+          memref::getMixedSize(rewriter, loc, *srcBuffer, concatDim);
+      sizes[concatDim] = concatDimSize;
 
       // Create a subview of the destination buffer.
       auto dstMemrefType = cast<MemRefType>(memrefType);
@@ -1197,12 +1177,7 @@ struct ConcatOpInterface
       if (failed(options.createMemCpy(rewriter, loc, *srcBuffer, subview)))
         return failure();
 
-      if (dynamicConcatDim) {
-        dynamicOffset = arith::AddIOp::create(
-            rewriter, loc, dynamicOffset.value(), dynamicSize.value());
-      } else {
-        concatDimOffset += operandConcatDimSize;
-      }
+      concatDimOffset = sum(concatDimOffset, concatDimSize);
     }
 
     replaceOpWithBufferizedValues(rewriter, op, dstBuffer);
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index 5eb2360a29b8f..be8ce20d8f154 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -678,11 +678,9 @@ func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf3
 // CHECK-DAG:       %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
 // CHECK:           %[[ALLOC:.*]] = memref.alloc
 // CHECK-SAME:                                    memref<8x?xf32>
-// CHECK-DAG:       %[[OFFSET:.*]] = arith.constant 0 : index
-// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, %[[F_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK:           %[[OFFSET_2:.*]] = arith.addi %[[OFFSET]], %[[F_DIM]] : index
-// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:           return %[[RET]]
@@ -706,10 +704,9 @@ func.func @tensor.concat_dynamic(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>) -> te
 // CHECK:           %[[ALLOC:.*]] = memref.alloc
 // CHECK-SAME:                                    memref<?x?xf32>
 // CHECK-DAG:       %[[NON_CONCAT_DIM:.*]] = memref.dim %[[ALLOC]], %[[c0]]
-// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, %[[c0]]] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [%[[NON_CONCAT_DIM]], %[[F_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
-// CHECK:           %[[OFFSET_2:.*]] = arith.addi %[[c0]], %[[F_DIM]] : index
-// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET_2]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [%[[NON_CONCAT_DIM]], %[[G_DIM]]] [1, 1]
 // CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
 // CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
 // CHECK:           return %[[RET]]
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
 
 // -----
 
+// CHECK:  #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL:   func @tensor.concat_mixed_dynamic_static(
+// CHECK-SAME:        %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
+// CHECK-SAME:        %[[H:.*]]: tensor<8x2xf32>)
+// CHECK-DAG:       %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG:       %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG:       %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
+// CHECK-DAG:       %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
+// CHECK-DAG:       %[[c1:.*]] = arith.constant 1 : index
+// CHECK:           %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK:           %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK:           %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK:           %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK:           memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK:           %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
+// CHECK:           %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
+// CHECK:           memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
+// CHECK:           %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:           return %[[RET]]
+// CHECK:         }
+func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
+  %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
+  return %0 : tensor<8x10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @tensor.splat_dynamic(
 // CHECK-SAME:  %[[F:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME:  %[[M:[a-zA-Z0-9_]+]]: index

``````````

</details>


https://github.com/llvm/llvm-project/pull/168676


More information about the Mlir-commits mailing list