[Mlir-commits] [mlir] [mlir][bufferization] implement BufferizableOpInterface for concat op (PR #140171)
Jeremy Kun
llvmlistbot at llvm.org
Thu May 15 20:57:42 PDT 2025
================
@@ -615,6 +615,48 @@ func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
// -----
+// CHECK-LABEL: func @tensor.concat(
+// CHECK-SAME: %[[F:.*]]: tensor<8xf32>)
+// CHECK: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<16xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0] [8] [1]
+// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][8] [8] [1]
+// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat(%f: tensor<8xf32>) -> tensor<16xf32> {
+ %t = tensor.concat dim(0) %f, %f : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32>
+ return %t : tensor<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor.concat_different_shapes(
+// CHECK-SAME: %[[F:.*]]: tensor<8x4xf32>
+// CHECK-SAME: %[[G:.*]]: tensor<8x5xf32>
+// CHECK-DAG: %[[F_MEMREF:.*]] = bufferization.to_memref %[[F]]
+// CHECK-DAG: %[[G_MEMREF:.*]] = bufferization.to_memref %[[G]]
+// CHECK: memref.copy %[[F_MEMREF]], %[[F_ALLOC:.*]] :
+// CHECK: memref.copy %[[G_MEMREF]], %[[F_ALLOC_2:.*]] :
+// CHECK: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<8x9xf32>
+// CHECK: %[[SUBVIEW1:.*]] = memref.subview %[[ALLOC]][0, 0] [8, 4] [1, 1]
+// CHECK: memref.copy %[[F_ALLOC]], %[[SUBVIEW1]]
+// CHECK: %[[SUBVIEW2:.*]] = memref.subview %[[ALLOC]][0, 4] [8, 5] [1, 1]
+// CHECK: memref.copy %[[F_ALLOC_2]], %[[SUBVIEW2]]
+// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK: return %[[RET]]
+// CHECK: }
+func.func @tensor.concat_different_shapes(%f: tensor<8x4xf32>, %g: tensor<8x5xf32>) -> tensor<8x9xf32> {
+ %t = tensor.concat dim(1) %f, %g : (tensor<8x4xf32>, tensor<8x5xf32>) -> tensor<8x9xf32>
----------------
j2kun wrote:
Done
https://github.com/llvm/llvm-project/pull/140171