[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)

Jack Frankland llvmlistbot at llvm.org
Mon Oct 23 08:38:16 PDT 2023


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/68567

>From a7c5deff9448d5a1108f5acd4706515e19908944 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 25 Sep 2023 19:55:39 +0100
Subject: [PATCH] [mlir][linalg] Add TransposeConv2D transform op

* Add a linalg transform op to convert 2D convolutions and quantized 2D
  convolutions that have the `FHWC` filter channel ordering into a
  transpose followed by 2D convolutions that have the `HWCF` channel
  ordering.

* Add a lit test to check the semantics of the transformation are
  correct for both quantized and unquantized variants.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
 .../Linalg/TransformOps/LinalgTransformOps.td |  50 ++++
 .../Dialect/Linalg/Transforms/Transforms.h    |   7 +
 .../TransformOps/LinalgTransformOps.cpp       |  27 ++
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/TransposeConv2D.cpp     | 150 +++++++++++
 .../test/Dialect/Linalg/transpose-conv2d.mlir | 254 ++++++++++++++++++
 6 files changed, 489 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/transpose-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 1ff88d036bc036c..a7ec18a618c1b1a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2210,6 +2210,56 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Transpose Conv2D
+//===----------------------------------------------------------------------===//
+
+def TransposeConv2DOp : Op<Transform_Dialect,
+    "structured.transpose_conv2d",
+    [FunctionalStyleTransformOpTrait,
+     MemoryEffectsOpInterface,
+     TransformOpInterface,
+     TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Convert linalg.conv_2d_nhwc_fhwc(_q) into linalg.conv_2d_nhwc_hwcf(_q) by
+    introducing a linalg.transpose on the filter tensor/memref.
+
+    Whilst the fhwc filter channel ordering can be desirable for certain targets
+    and is a more direct mapping to higher level dialects such as TOSA (which only
+    supports this ordering) hwcf is better suited for transformations such as
+    img2col which can make use of optimized BLAS routines such as GEMM.
+
+    Returns one handle:
+    - The final operation of the sequence that replaces the original
+      convolution.
+
+    #### Return modes:
+
+    Returns a silenceable failure if the target is not a
+    linalg.conv_2d_nhwc_fhwc(_q) operation or if the rewrite failed.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  // NOTE(review): this builder is declared without an inline body, so it needs
+  // an out-of-line C++ definition; none is visible in this patch - confirm.
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  // Invoked once per payload operation associated with the $target handle.
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+
 //===----------------------------------------------------------------------===//
 // HoistRedundantTensorSubsetsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 3597209d7f90c25..6348edce513be13 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1199,6 +1199,13 @@ rewriteInIm2Col(RewriterBase &rewriter,
 FailureOr<std::pair<Operation *, Operation *>>
 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp);
 
+/// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by
+/// transposing the filter with linalg.transpose; returns the new convolution.
+FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
+                                       linalg::Conv2DNhwcFhwcOp op);
+FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
+                                       linalg::Conv2DNhwcFhwcQOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8508507871d0c6c..07ac827be77a174 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3139,6 +3139,33 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Build the replacement IR immediately before the payload operation.
+  rewriter.setInsertionPoint(target);
+
+  // Both supported convolution kinds are forwarded to the matching
+  // transposeConv2D overload; any other linalg op is a match failure.
+  FailureOr<Operation *> newConv =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>(
+              [&](auto convOp) { return transposeConv2D(rewriter, convOp); })
+          .Default([&](Operation *other) {
+            return rewriter.notifyMatchFailure(other, "not supported");
+          });
+
+  if (succeeded(newConv)) {
+    // Expose the new convolution (the one consuming the transposed filter)
+    // through the result handle.
+    results.push_back(*newConv);
+    return DiagnosedSilenceableFailure::success();
+  }
+  return emitDefaultSilenceableFailure(target);
+}
+
 //===----------------------------------------------------------------------===//
 // HoistRedundantTensorSubsetsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..823b7bfd9810804 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Tiling.cpp
   TilingInterfaceImpl.cpp
   Transforms.cpp
+  TransposeConv2D.cpp
   Vectorization.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
new file mode 100644
index 000000000000000..9e0829ee67c0136
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
@@ -0,0 +1,150 @@
+//===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/RWMutex.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+namespace linalg {
+namespace {
+// clang-format off
+/// Rewrites a 2D convolution whose filter uses the FHWC layout into the
+/// equivalent convolution whose filter uses the HWCF layout, materializing the
+/// filter re-layout as a linalg.transpose:
+///
+/// Before:
+///
+///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+///                                               strides = dense<2> : tensor<2xi64>}
+///      ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+///     outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+///
+/// After:
+///
+///    %0 = tensor.empty() : tensor<2x2x6x8xf32>
+///    %transposed = linalg.transpose ins(%filter : tensor<8x2x2x6xf32>) outs(%0 : tensor<2x2x6x8xf32>)
+///                  permutation = [1, 2, 3, 0]
+///    %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+///         ins(%input, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%init : tensor<1x2x2x8xf32>)
+///         -> tensor<1x2x2x8xf32>
+///
+/// with an analogous rewrite for the quantized case. When the filter is a
+/// memref, the transpose writes into a fresh memref.alloc instead of a
+/// tensor.empty and the new convolution reads that buffer.
+///
+/// \param rewriter Rewriter used to create the new ops and replace `op`; new
+///                 ops are created at its current insertion point.
+/// \param op       The FHWC convolution to rewrite.
+/// \returns The newly created HWCF convolution, or failure (leaving the IR
+///          untouched) if the filter does not have a fully static shape.
+// clang-format on
+template <typename FHWCConvOp, typename HWCFConvOp>
+FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter,
+                                             FHWCConvOp op) {
+  // Permutation taking the FHWC filter layout to HWCF. For a 2D convolution
+  // this is known statically.
+  SmallVector<int64_t> filterPerm({1, 2, 3, 0});
+
+  auto filter = op->getOperand(1);
+  auto filterTy = cast<ShapedType>(filter.getType());
+
+  // The destination of the transpose is built below from a static shape only
+  // (no dynamic size operands are provided), so dynamic filters are rejected
+  // up front instead of creating invalid IR.
+  if (!filterTy.hasStaticShape())
+    return rewriter.notifyMatchFailure(op.getOperation(),
+                                       "expected filter with a static shape");
+
+  // Shape of the transposed filter: newShape[i] = oldShape[perm[i]].
+  ArrayRef<int64_t> filterShape = filterTy.getShape();
+  SmallVector<int64_t> newFilterShape;
+  newFilterShape.reserve(filterPerm.size());
+  for (int64_t dim : filterPerm)
+    newFilterShape.push_back(filterShape[dim]);
+
+  // linalg.transpose needs an "outs" operand of the transposed type. The
+  // transpose overwrites every element, so an uninitialized destination is
+  // enough. Both the element type and the tensor/memref decision come from the
+  // filter itself, since that is the value being transposed.
+  auto elementTy = filterTy.getElementType();
+  auto loc = op->getLoc();
+
+  const bool isTensorFilter = isa<TensorType>(filterTy);
+  Value transposeInit;
+  if (isTensorFilter) {
+    transposeInit =
+        rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
+            .getResult();
+  } else {
+    // NOTE(review): no matching memref.dealloc is created for this buffer;
+    // presumably a later deallocation pass is expected to handle it - confirm.
+    transposeInit = rewriter
+                        .create<memref::AllocOp>(
+                            loc, MemRefType::get(newFilterShape, elementTy))
+                        .getResult();
+  }
+
+  // We can then construct the transposition on our filter.
+  auto transpose = rewriter.create<linalg::TransposeOp>(loc, filter,
+                                                        transposeInit,
+                                                        filterPerm);
+
+  // On tensors the transposed filter is the op result; on memrefs the
+  // transpose has no result and the data lives in the destination buffer.
+  Value newFilter;
+  if (isTensorFilter) {
+    newFilter = transpose.getResult()[0];
+  } else {
+    newFilter = transposeInit;
+  }
+
+  SmallVector<Value> newInputs{op.getInputs()};
+  // The filter is always the second input argument, the other inputs can be
+  // left as they are.
+  newInputs[1] = newFilter;
+  // It is possible the convolution doesn't define any results and its
+  // out argument is just used instead.
+  SmallVector<Type> resultTy;
+  if (op.getNumResults()) {
+    resultTy.push_back(op->getResult(0).getType());
+  }
+  auto newConv =
+      rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(),
+                                  op.getStrides(), op.getDilations());
+  rewriter.replaceOp(op, newConv);
+  return newConv.getOperation();
+}
+
+/// Rewrite-pattern wrapper around transposeConv2DHelper: matches every
+/// `FHWCConvOp` and replaces it with the corresponding `HWCFConvOp` that
+/// consumes a transposed filter.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // Only the success/failure status is reported to the pattern driver; the
+    // operation created by the helper is not needed here.
+    FailureOr<Operation *> rewritten =
+        transposeConv2DHelper<FHWCConvOp, HWCFConvOp>(rewriter, op);
+    return success(succeeded(rewritten));
+  }
+};
+} // namespace
+
+FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
+                                       linalg::Conv2DNhwcFhwcOp op) {
+  // Unquantized variant: conv_2d_nhwc_fhwc -> conv_2d_nhwc_hwcf.
+  using SourceOp = linalg::Conv2DNhwcFhwcOp;
+  using TargetOp = linalg::Conv2DNhwcHwcfOp;
+  return transposeConv2DHelper<SourceOp, TargetOp>(rewriter, op);
+}
+
+FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
+                                       linalg::Conv2DNhwcFhwcQOp op) {
+  // Quantized variant: conv_2d_nhwc_fhwc_q -> conv_2d_nhwc_hwcf_q.
+  using SourceOp = linalg::Conv2DNhwcFhwcQOp;
+  using TargetOp = linalg::Conv2DNhwcHwcfQOp;
+  return transposeConv2DHelper<SourceOp, TargetOp>(rewriter, op);
+}
+
+/// Registers the FHWC -> HWCF conversion patterns for both the plain and the
+/// quantized 2D convolution in `patterns`.
+// NOTE(review): the function name is spelled "Tranpose"; it is kept as-is
+// here because renaming would change the externally visible symbol.
+void populateTranposeConv2DPatterns(RewritePatternSet &patterns) {
+  using PlainConverter =
+      ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>;
+  using QuantizedConverter =
+      ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>;
+  patterns.add<PlainConverter, QuantizedConverter>(patterns.getContext());
+}
+} // namespace linalg
+} // namespace mlir
diff --git a/mlir/test/Dialect/Linalg/transpose-conv2d.mlir b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
new file mode 100644
index 000000000000000..ff1d050a311363c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
@@ -0,0 +1,254 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f64
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf64>, %[[FILTER:.+]]: tensor<8x2x2x6xf64>, %[[INIT:.+]]: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf64>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf64>) outs(%[[NEWF]] : tensor<2x2x6x8xf64>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf64>, tensor<2x2x6x8xf64>) outs(%[[INIT]] : tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xf64>
+func.func @conv_2d_nhwc_fhwc_f64(%input: tensor<1x4x4x6xf64>, %filter: tensor<8x2x2x6xf64>, %init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf64>, tensor<8x2x2x6xf64>)
+    outs (%init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
+  return %0 : tensor<1x2x2x8xf64>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
+func.func @conv_2d_nhwc_fhwc_f32(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  return %0 : tensor<1x2x2x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f16
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf16>, %[[FILTER:.+]]: tensor<8x2x2x6xf16>, %[[INIT:.+]]: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf16>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf16>) outs(%[[NEWF]] : tensor<2x2x6x8xf16>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf16>, tensor<2x2x6x8xf16>) outs(%[[INIT]] : tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xf16>
+func.func @conv_2d_nhwc_fhwc_f16(%input: tensor<1x4x4x6xf16>, %filter: tensor<8x2x2x6xf16>, %init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf16>, tensor<8x2x2x6xf16>)
+    outs (%init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
+  return %0 : tensor<1x2x2x8xf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_b16
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xbf16>, %[[FILTER:.+]]: tensor<8x2x2x6xbf16>, %[[INIT:.+]]: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xbf16>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xbf16>) outs(%[[NEWF]] : tensor<2x2x6x8xbf16>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xbf16>, tensor<2x2x6x8xbf16>) outs(%[[INIT]] : tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xbf16>
+func.func @conv_2d_nhwc_fhwc_b16(%input: tensor<1x4x4x6xbf16>, %filter: tensor<8x2x2x6xbf16>, %init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xbf16>, tensor<8x2x2x6xbf16>)
+    outs (%init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
+  return %0 : tensor<1x2x2x8xbf16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i64
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi64>, %[[FILTER:.+]]: tensor<8x2x2x6xi64>, %[[INIT:.+]]: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi64>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi64>) outs(%[[NEWF]] : tensor<2x2x6x8xi64>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi64>, tensor<2x2x6x8xi64>) outs(%[[INIT]] : tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xi64>
+func.func @conv_2d_nhwc_fhwc_i64(%input: tensor<1x4x4x6xi64>, %filter: tensor<8x2x2x6xi64>, %init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xi64>, tensor<8x2x2x6xi64>)
+    outs (%init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
+  return %0 : tensor<1x2x2x8xi64>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi32>, %[[FILTER:.+]]: tensor<8x2x2x6xi32>, %[[INIT:.+]]: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi32>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi32>) outs(%[[NEWF]] : tensor<2x2x6x8xi32>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi32>, tensor<2x2x6x8xi32>) outs(%[[INIT]] : tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xi32>
+func.func @conv_2d_nhwc_fhwc_i32(%input: tensor<1x4x4x6xi32>, %filter: tensor<8x2x2x6xi32>, %init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xi32>, tensor<8x2x2x6xi32>)
+    outs (%init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
+  return %0 : tensor<1x2x2x8xi32>
+}
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i16
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi16>, %[[FILTER:.+]]: tensor<8x2x2x6xi16>, %[[INIT:.+]]: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi16>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi16>) outs(%[[NEWF]] : tensor<2x2x6x8xi16>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi16>, tensor<2x2x6x8xi16>) outs(%[[INIT]] : tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xi16>
+func.func @conv_2d_nhwc_fhwc_i16(%input: tensor<1x4x4x6xi16>, %filter: tensor<8x2x2x6xi16>, %init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xi16>, tensor<8x2x2x6xi16>)
+    outs (%init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
+  return %0 : tensor<1x2x2x8xi16>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i8
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi8>, %[[FILTER:.+]]: tensor<8x2x2x6xi8>, %[[INIT:.+]]: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi8>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi8>) outs(%[[NEWF]] : tensor<2x2x6x8xi8>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi8>, tensor<2x2x6x8xi8>) outs(%[[INIT]] : tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xi8>
+func.func @conv_2d_nhwc_fhwc_i8(%input: tensor<1x4x4x6xi8>, %filter: tensor<8x2x2x6xi8>, %init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xi8>, tensor<8x2x2x6xi8>)
+    outs (%init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
+  return %0 : tensor<1x2x2x8xi8>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_q
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
+  func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  return %0 : tensor<1x2x2x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc_q"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_unit_stride
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
+// CHECK:    return %[[CONV]] : tensor<1x3x3x8xf32>
+func.func @conv_2d_nhwc_fhwc_f32_unit_stride(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+    outs (%init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
+  return %0 : tensor<1x3x3x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_2_dialation
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+// CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+// CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
+func.func @conv_2d_nhwc_fhwc_f32_2_dialation(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<2> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  return %0 : tensor<1x2x2x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}
+
+// -----
+
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<1x4x4x6xf32>, %[[FILTER:.+]]: memref<8x2x2x6xf32>, %[[INIT:.+]]: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
+// CHECK-DAG:    %[[NEWF:.+]] = memref.alloc() : memref<2x2x6x8xf32>
+// CHECK:    linalg.transpose ins(%[[FILTER]] : memref<8x2x2x6xf32>) outs(%[[NEWF]] : memref<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK:    linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[NEWF]] : memref<1x4x4x6xf32>, memref<2x2x6x8xf32>) outs(%[[INIT]] : memref<1x2x2x8xf32>)
+// CHECK:    return %[[INIT]] : memref<1x2x2x8xf32>
+func.func @conv_2d_nhwc_fhwc_memref(%input: memref<1x4x4x6xf32>, %filter: memref<8x2x2x6xf32>, %init: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
+  linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                            strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: memref<1x4x4x6xf32>, memref<8x2x2x6xf32>)
+    outs (%init: memref<1x2x2x8xf32>)
+  return %init : memref<1x2x2x8xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.transpose_conv2d %0 : (!transform.any_op) -> (!transform.any_op)
+}



More information about the Mlir-commits mailing list