[Mlir-commits] [mlir] [mlir] [linalg] Add canonicalize pattern to swap transpose with broadcast (PR #97063)

donald chen llvmlistbot at llvm.org
Thu Jul 4 22:08:33 PDT 2024


================
@@ -1096,3 +1096,54 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+// -----
+
+func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
+                                    %init1: tensor<1x2x3x4x5x6xf32>,
+                                    %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
+  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
+  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
+  //       CHECK:   return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2x4x5xf32>)
+      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
+      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+      permutation = [0, 5, 1, 2, 4, 3]
+  func.return %transpose : tensor<1x6x2x3x5x4xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
+                                            %init1: tensor<1x?x3x?x5x6xf32>,
+                                            %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_dynamic
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+  //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+  //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+  //       CHECK:   %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
+  //       CHECK:   %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
+  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
+  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+  //       CHECK:   return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<?x?x5xf32>)
+      outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
+      outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+      permutation = [0, 2, 3, 5, 4, 1]
+  func.return %transpose : tensor<1x3x?x6x5x?xf32>
+}
----------------
cxy-1993 wrote:

DONE

https://github.com/llvm/llvm-project/pull/97063


More information about the Mlir-commits mailing list