[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)
Jack Frankland
llvmlistbot at llvm.org
Fri Oct 20 07:53:18 PDT 2023
https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/68567
>From 923b8b7154a20c2677d1b3eddfe5a8718849ad1c Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 25 Sep 2023 19:55:39 +0100
Subject: [PATCH] [mlir][linalg] Add TransposeConv2D Pass
* Add a LinAlg pass to convert 2D convolutions and quantized 2D
convolutions that have the `FHWC` filter channel ordering into a
transpose followed by 2D convolutions that have the `HWCF` channel
ordering.
* Add a lit test to check the semantics of the transformation are
correct for both quantized and unquantized variants.
Signed-off-by: Jack Frankland <jack.frankland at arm.com>
---
mlir/include/mlir/Dialect/Linalg/Passes.h | 4 +
mlir/include/mlir/Dialect/Linalg/Passes.td | 26 +++
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Linalg/Transforms/TransposeConv2D.cpp | 144 +++++++++++++++
.../test/Dialect/Linalg/transpose-conv2d.mlir | 169 ++++++++++++++++++
5 files changed, 344 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
create mode 100644 mlir/test/Dialect/Linalg/transpose-conv2d.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5f46affe592a2da..96c809f10323922 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,6 +65,10 @@ std::unique_ptr<Pass> createLinalgGeneralizationPass();
/// work on primitive types, if possible.
std::unique_ptr<Pass> createLinalgDetensorizePass();
+/// Create a pass to convert linalg.conv_2d_nhwc_fhwc(_q) to
+/// linalg.conv_2d_nhwc_hwcf(_q).
+std::unique_ptr<Pass> createLinalgTransposeConv2DPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 3093604af63e338..d24c43a8dec6c71 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -145,4 +145,30 @@ def LinalgDetensorize : InterfacePass<"linalg-detensorize", "FunctionOpInterface
];
}
+// Pass definition for the FHWC -> HWCF filter transposition of 2D
+// convolutions (see TransposeConv2D.cpp for the rewrite itself).
+def LinalgTransposeConv2D : Pass<"linalg-transpose-conv2d-ops"> {
+  let summary = "Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf by transposing the filter";
+  let constructor = "mlir::createLinalgTransposeConv2DPass()";
+  // The rewrite materializes `tensor.empty` (tensor operands) or
+  // `memref.alloc` (memref operands) as the destination of the transpose, so
+  // both dialects are declared explicitly alongside linalg.
+  let dependentDialects = [
+    "linalg::LinalgDialect",
+    "memref::MemRefDialect",
+    "tensor::TensorDialect"
+  ];
+
+  let description = [{
+    This pass converts NHWC Conv2D operations with FHWC channel orderings to NHWC
+    Conv2D operations with HWCF channel orderings. Both linalg.conv_2d_nhwc_fhwc
+    and its quantized counterpart linalg.conv_2d_nhwc_fhwc_q are handled.
+
+    Applying a conversion targeting Linalg from a higher level dialect such as TOSA where
+    filter orderings follow the FHWC convention will result in linalg.conv_2d_nhwc_fhwc
+    operations being materialized. Subsequent optimizations such as img2col which may make
+    use of optimized BLAS subroutines such as GEMM require the HWCF ordering.
+
+    By applying the linalg-transpose-conv2d-ops pass a pipeline can optionally convert FHWC
+    filter orderings to HWCF thereby allowing them to benefit from subsequent optimizations
+    such as img2col. Conversely targets for which the FHWC ordering is more beneficial can
+    choose not to run this pass and don't need to revert the transformation since it is not
+    part of the "higher level dialect" -> Linalg conversion passes.
+
+    Currently this pass only supports FHWC->HWCF transpositions but could be extended to
+    support higher dimensional tensors and configurable reverse transpositions such as
+    HWCF->FHWC.
+  }];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..823b7bfd9810804 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Tiling.cpp
TilingInterfaceImpl.cpp
Transforms.cpp
+ TransposeConv2D.cpp
Vectorization.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
new file mode 100644
index 000000000000000..8236b8e95ba1970
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
@@ -0,0 +1,144 @@
+//===- TransposeConv2D.cpp - Convolution transposition -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/RWMutex.h"
+#include <memory>
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSECONV2D
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+// clang-format off
+/// Convolution converter that applies the following rewrite:
+///
+/// Before:
+///
+///   %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+///                                  strides = dense<2> : tensor<2xi64>}
+///          ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+///          outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+///
+/// After:
+///
+///   %0 = tensor.empty() : tensor<2x2x6x8xf32>
+///   %transposed = linalg.transpose ins(%filter : tensor<8x2x2x6xf32>) outs(%0 : tensor<2x2x6x8xf32>)
+///                 permutation = [1, 2, 3, 0]
+///   %1 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
+///          ins(%input, %transposed : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%init : tensor<1x2x2x8xf32>)
+///          -> tensor<1x2x2x8xf32>
+///
+/// with an analogous example for the quantized case. When the filter is a
+/// memref, the transpose destination is a `memref.alloc` instead of a
+/// `tensor.empty` and the allocated buffer itself is fed to the new
+/// convolution.
+///
+/// The pattern does not match (and leaves the op untouched) when the filter
+/// does not have a fully static shape.
+// clang-format on
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+public:
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+
+  /// Replaces `op` (FHWC filter layout) with a `linalg.transpose` of its
+  /// filter followed by the equivalent HWCF convolution. Returns failure
+  /// without modifying the IR if the filter shape is not static.
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // FHWC -> HWCF: dimension `i` of the transposed filter is dimension
+    // `filterPerm[i]` of the original one.
+    const SmallVector<int64_t> filterPerm({1, 2, 3, 0});
+
+    // The filter is always the second operand of the convolution.
+    Value filter = op->getOperand(1);
+    auto filterTy = cast<ShapedType>(filter.getType());
+
+    // The destination of the transpose is built from a static shape only;
+    // dynamic extents would need size operands that are not computed here.
+    if (!filterTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "expected filter operand to have a static shape");
+
+    // Permute the filter shape to obtain the shape of the transposed filter.
+    ArrayRef<int64_t> filterShape = filterTy.getShape();
+    SmallVector<int64_t> newFilterShape;
+    newFilterShape.reserve(filterPerm.size());
+    for (int64_t srcDim : filterPerm)
+      newFilterShape.push_back(filterShape[srcDim]);
+
+    // linalg.transpose writes into a destination ("outs") operand. It must
+    // have the element type of the value being transposed, i.e. the filter,
+    // which is not necessarily that of the image operand.
+    Type elementTy = filterTy.getElementType();
+    Location loc = op->getLoc();
+    const bool isTensorOp = isa<TensorType>(filterTy);
+
+    Value transposeInit;
+    if (isTensorOp) {
+      // The contents are fully overwritten by the transpose, so an
+      // uninitialized tensor is sufficient.
+      transposeInit =
+          rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy)
+              .getResult();
+    } else {
+      // NOTE(review): no matching memref.dealloc is created here; releasing
+      // the buffer is presumably left to a later deallocation pass - confirm
+      // for pipelines that use this pattern on memrefs.
+      transposeInit = rewriter
+                          .create<memref::AllocOp>(
+                              loc, MemRefType::get(newFilterShape, elementTy))
+                          .getResult();
+    }
+
+    // Transpose the filter into the freshly created destination.
+    auto transpose = rewriter.create<linalg::TransposeOp>(
+        loc, filter, transposeInit, filterPerm);
+
+    // With tensors the transposed value is the op result; with memrefs the
+    // transpose has no result and the destination buffer holds the data.
+    Value newFilter = transposeInit;
+    if (isTensorOp)
+      newFilter = transpose.getResult()[0];
+
+    // Only the filter changes; every other input (image, and the extra
+    // scalar operands of the quantized variant) is forwarded as is.
+    SmallVector<Value> newInputs{op.getInputs()};
+    newInputs[1] = newFilter;
+
+    // It is possible the convolution doesn't define any results and its
+    // out argument is just used instead.
+    SmallVector<Type> resultTy;
+    if (op.getNumResults())
+      resultTy.push_back(op->getResult(0).getType());
+
+    rewriter.replaceOpWithNewOp<HWCFConvOp>(op, resultTy, newInputs,
+                                            op.getOutputs(), op.getStrides(),
+                                            op.getDilations());
+    return success();
+  }
+};
+
+/// Pass driver: greedily applies the FHWC -> HWCF convolution rewrites, for
+/// both the plain and the quantized convolution ops, to the operation the
+/// pass is run on. The pass fails if the greedy driver reports failure.
+struct LinalgTransposeConv2D
+    : public impl::LinalgTransposeConv2DBase<LinalgTransposeConv2D> {
+public:
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    MLIRContext *context = root->getContext();
+
+    // One pattern instance per (source op, target op) pair.
+    RewritePatternSet patterns(context);
+    patterns.add<
+        ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>>(
+        context);
+    patterns.add<
+        ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
+        context);
+
+    if (succeeded(applyPatternsAndFoldGreedily(root, std::move(patterns))))
+      return;
+    signalPassFailure();
+  }
+};
+} // namespace
+
+/// Factory for the FHWC -> HWCF convolution transposition pass; returns a new
+/// instance on every call.
+std::unique_ptr<Pass> mlir::createLinalgTransposeConv2DPass() {
+  std::unique_ptr<Pass> pass = std::make_unique<LinalgTransposeConv2D>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/Linalg/transpose-conv2d.mlir b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
new file mode 100644
index 000000000000000..6143d60557f5bd9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(linalg-transpose-conv2d-ops))' | FileCheck %s
+
+// f64 tensors, strides = 2, dilations = 1: the filter is transposed with
+// permutation [1, 2, 3, 0] into a tensor.empty destination and the
+// convolution becomes linalg.conv_2d_nhwc_hwcf with unchanged attributes.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f64
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf64>, %[[FILTER:.+]]: tensor<8x2x2x6xf64>, %[[INIT:.+]]: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf64>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf64>) outs(%[[NEWF]] : tensor<2x2x6x8xf64>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf64>, tensor<2x2x6x8xf64>) outs(%[[INIT]] : tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xf64>
+func.func @conv_2d_nhwc_fhwc_f64(%input: tensor<1x4x4x6xf64>, %filter: tensor<8x2x2x6xf64>, %init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xf64>, tensor<8x2x2x6xf64>)
+ outs (%init: tensor<1x2x2x8xf64>) -> tensor<1x2x2x8xf64>
+ return %0 : tensor<1x2x2x8xf64>
+}
+
+// Same rewrite as the f64 case, exercised with f32 element types.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
+func.func @conv_2d_nhwc_fhwc_f32(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+ outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+ return %0 : tensor<1x2x2x8xf32>
+}
+
+// Same rewrite, exercised with f16 element types.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f16
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf16>, %[[FILTER:.+]]: tensor<8x2x2x6xf16>, %[[INIT:.+]]: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf16>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf16>) outs(%[[NEWF]] : tensor<2x2x6x8xf16>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf16>, tensor<2x2x6x8xf16>) outs(%[[INIT]] : tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xf16>
+func.func @conv_2d_nhwc_fhwc_f16(%input: tensor<1x4x4x6xf16>, %filter: tensor<8x2x2x6xf16>, %init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xf16>, tensor<8x2x2x6xf16>)
+ outs (%init: tensor<1x2x2x8xf16>) -> tensor<1x2x2x8xf16>
+ return %0 : tensor<1x2x2x8xf16>
+}
+
+// Same rewrite, exercised with bf16 element types.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_b16
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xbf16>, %[[FILTER:.+]]: tensor<8x2x2x6xbf16>, %[[INIT:.+]]: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xbf16>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xbf16>) outs(%[[NEWF]] : tensor<2x2x6x8xbf16>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xbf16>, tensor<2x2x6x8xbf16>) outs(%[[INIT]] : tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xbf16>
+func.func @conv_2d_nhwc_fhwc_b16(%input: tensor<1x4x4x6xbf16>, %filter: tensor<8x2x2x6xbf16>, %init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xbf16>, tensor<8x2x2x6xbf16>)
+ outs (%init: tensor<1x2x2x8xbf16>) -> tensor<1x2x2x8xbf16>
+ return %0 : tensor<1x2x2x8xbf16>
+}
+
+// Same rewrite, exercised with i64 element types. The label carries the full
+// function name so that it cannot match any of the other test functions,
+// which all share the `@conv_2d_nhwc_fhwc` prefix.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i64
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi64>, %[[FILTER:.+]]: tensor<8x2x2x6xi64>, %[[INIT:.+]]: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi64>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi64>) outs(%[[NEWF]] : tensor<2x2x6x8xi64>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi64>, tensor<2x2x6x8xi64>) outs(%[[INIT]] : tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xi64>
+func.func @conv_2d_nhwc_fhwc_i64(%input: tensor<1x4x4x6xi64>, %filter: tensor<8x2x2x6xi64>, %init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xi64>, tensor<8x2x2x6xi64>)
+ outs (%init: tensor<1x2x2x8xi64>) -> tensor<1x2x2x8xi64>
+ return %0 : tensor<1x2x2x8xi64>
+}
+
+// Same rewrite, exercised with i32 element types.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i32
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi32>, %[[FILTER:.+]]: tensor<8x2x2x6xi32>, %[[INIT:.+]]: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi32>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi32>) outs(%[[NEWF]] : tensor<2x2x6x8xi32>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi32>, tensor<2x2x6x8xi32>) outs(%[[INIT]] : tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xi32>
+func.func @conv_2d_nhwc_fhwc_i32(%input: tensor<1x4x4x6xi32>, %filter: tensor<8x2x2x6xi32>, %init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xi32>, tensor<8x2x2x6xi32>)
+ outs (%init: tensor<1x2x2x8xi32>) -> tensor<1x2x2x8xi32>
+ return %0 : tensor<1x2x2x8xi32>
+}
+
+// Same rewrite, exercised with i16 element types.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i16
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi16>, %[[FILTER:.+]]: tensor<8x2x2x6xi16>, %[[INIT:.+]]: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi16>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi16>) outs(%[[NEWF]] : tensor<2x2x6x8xi16>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi16>, tensor<2x2x6x8xi16>) outs(%[[INIT]] : tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xi16>
+func.func @conv_2d_nhwc_fhwc_i16(%input: tensor<1x4x4x6xi16>, %filter: tensor<8x2x2x6xi16>, %init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xi16>, tensor<8x2x2x6xi16>)
+ outs (%init: tensor<1x2x2x8xi16>) -> tensor<1x2x2x8xi16>
+ return %0 : tensor<1x2x2x8xi16>
+}
+
+// Same rewrite, exercised with i8 element types.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_i8
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xi8>, %[[FILTER:.+]]: tensor<8x2x2x6xi8>, %[[INIT:.+]]: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xi8>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xi8>) outs(%[[NEWF]] : tensor<2x2x6x8xi8>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xi8>, tensor<2x2x6x8xi8>) outs(%[[INIT]] : tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xi8>
+func.func @conv_2d_nhwc_fhwc_i8(%input: tensor<1x4x4x6xi8>, %filter: tensor<8x2x2x6xi8>, %init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xi8>, tensor<8x2x2x6xi8>)
+ outs (%init: tensor<1x2x2x8xi8>) -> tensor<1x2x2x8xi8>
+ return %0 : tensor<1x2x2x8xi8>
+}
+
+// Quantized variant: linalg.conv_2d_nhwc_fhwc_q becomes
+// linalg.conv_2d_nhwc_hwcf_q. Only the filter operand is replaced by its
+// transpose; the two trailing i32 scalar inputs are forwarded unchanged.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_q
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
+ func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32)
+ outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+ return %0 : tensor<1x2x2x8xf32>
+}
+
+// Unit strides: verifies that `strides = dense<1>` is carried over to the
+// rewritten convolution (the output is 1x3x3x8 here instead of 1x2x2x8).
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_unit_stride
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
+// CHECK: return %[[CONV]] : tensor<1x3x3x8xf32>
+func.func @conv_2d_nhwc_fhwc_f32_unit_stride(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+ outs (%init: tensor<1x3x3x8xf32>) -> tensor<1x3x3x8xf32>
+ return %0 : tensor<1x3x3x8xf32>
+}
+
+// Non-unit dilation: verifies that `dilations = dense<2>` is carried over to
+// the rewritten convolution.
+// NOTE: "dialation" in the function name is a misspelling of "dilation"; it
+// is kept as is so that the label below and the symbol name stay in sync.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_f32_2_dialation
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[FILTER:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+// CHECK-DAG: %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+// CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[FILTER]] : tensor<8x2x2x6xf32>) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+// CHECK: return %[[CONV]] : tensor<1x2x2x8xf32>
+func.func @conv_2d_nhwc_fhwc_f32_2_dialation(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+ %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<2> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+ outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+ return %0 : tensor<1x2x2x8xf32>
+}
+
+// Memref operands: the convolution has no results, so the transposed filter
+// is written into a buffer created by memref.alloc and that buffer (rather
+// than a transpose result) is passed to the rewritten convolution.
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_memref
+// CHECK-SAME: (%[[INPUT:.+]]: memref<1x4x4x6xf32>, %[[FILTER:.+]]: memref<8x2x2x6xf32>, %[[INIT:.+]]: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
+// CHECK-DAG: %[[NEWF:.+]] = memref.alloc() : memref<2x2x6x8xf32>
+// CHECK: linalg.transpose ins(%[[FILTER]] : memref<8x2x2x6xf32>) outs(%[[NEWF]] : memref<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+// CHECK: linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[NEWF]] : memref<1x4x4x6xf32>, memref<2x2x6x8xf32>) outs(%[[INIT]] : memref<1x2x2x8xf32>)
+// CHECK: return %[[INIT]] : memref<1x2x2x8xf32>
+func.func @conv_2d_nhwc_fhwc_memref(%input: memref<1x4x4x6xf32>, %filter: memref<8x2x2x6xf32>, %init: memref<1x2x2x8xf32>) -> memref<1x2x2x8xf32> {
+ linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<2> : tensor<2xi64>}
+ ins (%input, %filter: memref<1x4x4x6xf32>, memref<8x2x2x6xf32>)
+ outs (%init: memref<1x2x2x8xf32>)
+ return %init : memref<1x2x2x8xf32>
+}
More information about the Mlir-commits
mailing list