[Mlir-commits] [mlir] [mlir][tensor] Add a tensor.concat operation (PR #72779)

Han-Chung Wang llvmlistbot at llvm.org
Mon Nov 27 15:18:35 PST 2023


================
@@ -15,6 +15,23 @@ func.func @cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?x?
 
 // -----
 
+// CHECK-LABEL: func @concat(
+func.func @concat(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>) {
+  // CHECK: tensor.concat dim(0) %{{.*}} : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  %0 = tensor.concat dim(0) %arg0 : (tensor<4x7x3xf32>) -> tensor<4x7x3xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  %3 = tensor.concat dim(1) %arg2, %arg2 : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x10x?xf32>
+  // CHECK: tensor.concat dim(1) %{{.*}} : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
+  %4 = tensor.concat dim(1) %arg2, %arg1, %arg0 : (tensor<?x?x?xf32>, tensor<4x4x3xf32>, tensor<4x7x3xf32>) -> tensor<4x?x3xf32>
----------------
hanhanW wrote:

[optional] This test is interesting because some of the result shapes can be inferred statically from the operands! Perhaps a follow-up could insert tensor.cast ops to help with shape inference. We have [similar features](https://github.com/llvm/llvm-project/blob/fea023b129190edeb503dbe947034f925bbda666/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp#L2264) in the Linalg dialect. Would you consider adding such inference to the tensor.concat canonicalization patterns?
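Here is a rough Python sketch of the inference being suggested. It is not the MLIR C++ implementation, and the helpers `infer_concat_shape` and `refine_shape` are hypothetical. The sketch assumes the rule the test cases above appear to follow:

- Along the concatenated dimension, the result size is the sum of the operand sizes, and it is dynamic if any operand size is dynamic.
- Every other dimension takes whatever static size the operands provide.

Dynamic sizes (`?`) are modeled as `None`.

```python
from typing import List, Optional, Sequence

# A dimension is either a static size (int) or dynamic (None), mirroring `?` in MLIR.
Dim = Optional[int]


def infer_concat_shape(dim: int, shapes: Sequence[Sequence[Dim]]) -> List[Dim]:
    """Hypothetical helper: infer the most static result shape of a concat along `dim`."""
    if not shapes:
        raise ValueError("concat needs at least one input")
    rank = len(shapes[0])
    if any(len(s) != rank for s in shapes):
        raise ValueError("all inputs must have the same rank")
    if not 0 <= dim < rank:
        raise ValueError("concat dim out of range")
    result: List[Dim] = []
    for i in range(rank):
        sizes = [s[i] for s in shapes]
        if i == dim:
            # Concatenated dim: the sum of the sizes, or dynamic if any size is dynamic.
            result.append(None if None in sizes else sum(sizes))
        else:
            # Other dims must agree, so any static size pins the result size.
            static = {d for d in sizes if d is not None}
            if len(static) > 1:
                raise ValueError(f"mismatched sizes in dim {i}: {sorted(static)}")
            result.append(static.pop() if static else None)
    return result


def refine_shape(declared: Sequence[Dim], inferred: Sequence[Dim]) -> List[Dim]:
    """Hypothetical helper: merge two compatible shapes, keeping every known static size."""
    if len(declared) != len(inferred):
        raise ValueError("rank mismatch")
    refined: List[Dim] = []
    for d, i in zip(declared, inferred):
        if d is not None and i is not None and d != i:
            raise ValueError(f"incompatible sizes {d} and {i}")
        refined.append(d if d is not None else i)
    return refined
```

For `%4` above, this gives `[4, None, 3]`, matching `tensor<4x?x3xf32>`. For `%2`, it gives `[4, 7, None]`, which is more static than the declared `tensor<?x?x?xf32>`. That is the case where a canonicalization could produce the refined type and insert a tensor.cast back to the declared type, as suggested above. Raising on mismatched static sizes is an assumption of this sketch, and the op's actual verifier may report such errors differently.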


https://github.com/llvm/llvm-project/pull/72779

