[Mlir-commits] [mlir] [mlir][tensor] Fix bug in `ConcatOpInterface` (PR #168676)
Matthias Springer
llvmlistbot at llvm.org
Sun Nov 23 20:36:22 PST 2025
================
@@ -721,6 +718,35 @@ func.func @tensor.concat_dynamic_nonconcat_dim(%f: tensor<?x?xf32>, %g: tensor<?
// -----
+// CHECK: #[[$sum_map:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+
+// CHECK-LABEL: func @tensor.concat_mixed_dynamic_static(
+// CHECK-SAME: %[[F:.*]]: tensor<8x?xf32>, %[[G:.*]]: tensor<8x?xf32>,
+// CHECK-SAME: %[[H:.*]]: tensor<8x2xf32>)
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_buffer %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_buffer %[[G]]
+// CHECK-DAG: %[[H_MEMREF:.*]] = bufferization.to_buffer %[[H]]
+// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {alignment = 64 : i64} : memref<8x10xf32>
+// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[F_DIM:.*]] = memref.dim %[[F_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, %[[F_DIM]]] [1, 1]
+// CHECK: memref.copy %[[F_MEMREF]], %[[SUBVIEW1]]
+// CHECK: %[[G_DIM:.*]] = memref.dim %[[G_MEMREF]], %[[c1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, %[[F_DIM]]] [8, %[[G_DIM]]] [1, 1]
+// CHECK: memref.copy %[[G_MEMREF]], %[[SUBVIEW2]]
+// CHECK: %[[OFFSET:.*]] = affine.apply #[[$sum_map]]()[%[[F_DIM]], %[[G_DIM]]]
+// CHECK: %[[SUBVIEW3:.*]] = memref.subview %[[ALLOC]][0, %[[OFFSET]]] [8, 2] [1, 1]
+// CHECK: memref.copy %[[H_MEMREF]], %[[SUBVIEW3]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_mixed_dynamic_static(%f: tensor<8x?xf32>, %g: tensor<8x?xf32>, %h: tensor<8x2xf32>) -> tensor<8x10xf32> {
+ %0 = tensor.concat dim(1) %f, %g, %h : (tensor<8x?xf32>, tensor<8x?xf32>, tensor<8x2xf32>) -> tensor<8x10xf32>
----------------
matthias-springer wrote:
Is there a problem with an explicit `tensor.cast`?
It would be nice to have a consistent op design across the tensor dialect. I believe one reason we chose `input dynamicity == output dynamicity` for `collapse_shape`/`expand_shape` is that it gives better error messages: if there is only one allowable output type, the verifier can print that type in its error message.
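The verifier argument can be sketched in Python. The sketch assumes a strict inference rule for `tensor.concat`: the concatenated dimension is static only when every input is static along it. The other dimensions must agree, and any static size among them fixes the result. `infer_concat_shape`, `verify_concat` and `format_type` are hypothetical helpers, not MLIR's C++ API. The handling of non-concatenated dimensions is an assumption of this sketch only. The point it illustrates is that once the result type is uniquely determined by the inputs, a mismatch error can name the expected type.

```python
from typing import Optional, Sequence

# A dimension is a static size (int) or dynamic (None), mirroring MLIR's `?`.
Dim = Optional[int]
Shape = tuple[Dim, ...]


def infer_concat_shape(dim: int, shapes: Sequence[Shape]) -> Shape:
    """Hypothetical strict inference for a concat along `dim`."""
    if not shapes:
        raise ValueError("concat needs at least one input")
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("all inputs must have the same rank")
    result: list[Dim] = []
    for d in range(rank):
        sizes = [s[d] for s in shapes]
        if d == dim:
            # Concatenated dim: static only if every input size is static.
            result.append(None if None in sizes else sum(sizes))
        else:
            # Assumption: non-concat dims must agree; a static size wins over `?`.
            static = {x for x in sizes if x is not None}
            if len(static) > 1:
                raise ValueError(f"mismatched sizes {sorted(static)} in dim {d}")
            result.append(static.pop() if static else None)
    return tuple(result)


def format_type(shape: Shape, elem: str = "f32") -> str:
    """Render a shape in MLIR's tensor type syntax, e.g. tensor<8x?xf32>."""
    if not shape:
        return f"tensor<{elem}>"
    dims = "x".join("?" if d is None else str(d) for d in shape)
    return f"tensor<{dims}x{elem}>"


def verify_concat(dim: int, shapes: Sequence[Shape], result: Shape) -> Optional[str]:
    """Return an error message naming the single allowable type, or None."""
    expected = infer_concat_shape(dim, shapes)
    if result != expected:
        return (f"result type {format_type(result)} does not match "
                f"the inferred type {format_type(expected)}")
    return None


# The mixed dynamic/static case from the test above.
print(verify_concat(1, [(8, None), (8, None), (8, 2)], (8, 10)))
```

Under this rule, `tensor.concat` in the mixed test would have to produce `tensor<8x?xf32>`. Reaching `tensor<8x10xf32>` would then take an explicit `tensor.cast`, which is the alternative the comment asks about.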
https://github.com/llvm/llvm-project/pull/168676