[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)

Jack Frankland llvmlistbot at llvm.org
Tue Oct 17 08:25:56 PDT 2023


================
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(linalg-transpose-conv2d-ops))' | FileCheck %s
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[WEIGHTS:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+func.func @conv_2d_nhwc_fhwc(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+  // CHECK-DAG:    %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+  // CHECK:    %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+  // CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[WEIGHTS]] : tensor<8x2x2x6xf32>) outs(%[[FILL]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+  // CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  // CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
----------------
FranklandJack wrote:

Not sure, it's just what I've done in the past, but I think your suggestion is a good one. I've moved them together :)

https://github.com/llvm/llvm-project/pull/68567
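The test above checks that `linalg.conv_2d_nhwc_fhwc` is rewritten into a `linalg.transpose` of the filter with `permutation = [1, 2, 3, 0]` (8x2x2x6 FHWC to 2x2x6x8 HWCF), followed by `linalg.conv_2d_nhwc_hwcf`. The following pure-Python sketch is an illustration only and is not part of the PR. It checks that equivalence numerically on the test's shapes (input 1x4x4x6, stride 2, dilation 1). It makes two assumptions: the output starts from zero (the MLIR ops accumulate into `outs`), and `linalg.transpose` maps result dimension `j` to input dimension `permutation[j]`. The helper names `tensor`, `at`, `transpose` and `conv_nhwc` are hypothetical.

```python
from itertools import product


def tensor(shape, fn):
    """Nested list of the given shape; the element at index tuple idx is fn(idx)."""
    def build(prefix, dims):
        if not dims:
            return fn(prefix)
        return [build(prefix + (i,), dims[1:]) for i in range(dims[0])]
    return build((), tuple(shape))


def at(t, idx):
    """Read one element of a nested list by an index sequence."""
    for i in idx:
        t = t[i]
    return t


def transpose(t, shape, perm):
    """Assumed linalg.transpose semantics: result dim j is input dim perm[j]."""
    out_shape = [shape[p] for p in perm]

    def elem(r):
        src = [0] * len(shape)
        for j, p in enumerate(perm):
            src[p] = r[j]
        return at(t, src)

    return tensor(out_shape, elem), out_shape


def conv_nhwc(inp, filt, layout, out_shape, c, kh, kw, stride=2, dil=1):
    """Direct 2-D convolution over an NHWC input, starting from a zero accumulator.

    `layout` selects how the filter is indexed: "hwcf" or "fhwc".
    """
    def elem(o):
        n, oh, ow, f = o
        acc = 0
        for y, x, ci in product(range(kh), range(kw), range(c)):
            pixel = inp[n][oh * stride + y * dil][ow * stride + x * dil][ci]
            w = filt[y][x][ci][f] if layout == "hwcf" else filt[f][y][x][ci]
            acc += pixel * w
        return acc

    return tensor(out_shape, elem)


# Shapes and attributes taken from the test: stride 2, dilation 1.
inp = tensor([1, 4, 4, 6], lambda i: (i[1] * 31 + i[2] * 7 + i[3]) % 11 - 5)
fhwc_shape = [8, 2, 2, 6]
fhwc = tensor(fhwc_shape, lambda i: (i[0] * 13 + i[1] * 5 + i[2] * 3 + i[3]) % 7 - 3)

# Rewrite: transpose the FHWC filter to HWCF, then run the HWCF convolution.
hwcf, hwcf_shape = transpose(fhwc, fhwc_shape, [1, 2, 3, 0])
ref = conv_nhwc(inp, fhwc, "fhwc", [1, 2, 2, 8], c=6, kh=2, kw=2)
new = conv_nhwc(inp, hwcf, "hwcf", [1, 2, 2, 8], c=6, kh=2, kw=2)

assert hwcf_shape == [2, 2, 6, 8]
assert ref == new
print("rewrite preserves the result:", ref == new)
```

Integer-valued tensors keep the comparison exact. The real ops work on `f32`, but the accumulation order there is not something this sketch models.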

