[Mlir-commits] [mlir] [mlir][linalg] Add TransposeConv2D Pass (PR #68567)

Jack Frankland llvmlistbot at llvm.org
Mon Oct 9 01:59:23 PDT 2023

https://github.com/FranklandJack created https://github.com/llvm/llvm-project/pull/68567

* Add a Linalg pass to convert 2D convolutions and quantized 2D convolutions that have the `FHWC` filter channel ordering into a transpose followed by 2D convolutions that have the `HWCF` channel ordering.

* Add a lit test to check the semantics of the transformation are correct for both quantized and unquantized variants.

>From 20cb7f183de7df340f96cd29a0a87d1ee691d428 Mon Sep 17 00:00:00 2001
From: Jack Frankland <jack.frankland at arm.com>
Date: Mon, 25 Sep 2023 19:55:39 +0100
Subject: [PATCH] [mlir][linalg] Add TransposeConv2D Pass

* Add a Linalg pass to convert 2D convolutions and quantized 2D
  convolutions that have the `FHWC` filter channel ordering into a
  transpose followed by 2D convolutions that have the `HWCF` channel ordering.

* Add a lit test to check the semantics of the transformation are
  correct for both quantized and unquantized variants.

Signed-off-by: Jack Frankland <jack.frankland at arm.com>
 mlir/include/mlir/Dialect/Linalg/Passes.h     |   4 +
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   6 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Linalg/Transforms/TransposeConv2D.cpp     | 116 ++++++++++++++++++
 .../test/Dialect/Linalg/transpose-conv2d.mlir |  33 +++++
 5 files changed, 160 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
 create mode 100644 mlir/test/Dialect/Linalg/transpose-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 5f46affe592a2da..96c809f10323922 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -65,6 +65,10 @@ std::unique_ptr<Pass> createLinalgGeneralizationPass();
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
+/// Create a pass to convert linalg.conv_2d_nhwc_fhwc(_q) to
+/// linalg.conv_2d_nhwc_hwcf(_q).
+std::unique_ptr<Pass> createLinalgTransposeConv2DPass();
 // Registration
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 3093604af63e338..74cbe0c354f9018 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -145,4 +145,10 @@ def LinalgDetensorize : InterfacePass<"linalg-detensorize", "FunctionOpInterface
+def LinalgTransposeConv2D : Pass<"linalg-transpose-conv2d-ops"> {
+  let summary = "Convert conv_2d_nhwc_fhwc to conv_2d_nhwc_hwcf by transposing the weights.";
+  let constructor = "mlir::createLinalgTransposeConv2DPass()";
+  let dependentDialects = ["linalg::LinalgDialect"];
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4e094609afa6a03..823b7bfd9810804 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
+  TransposeConv2D.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
new file mode 100644
index 000000000000000..a8dee1126031601
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp
@@ -0,0 +1,116 @@
+//===- TransposeConv2D.cpp - Convolution transposition  -------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include <memory>
+#include <numeric>
+namespace mlir {
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+using namespace mlir;
+namespace {
+// Convolution converter that matches linalg.conv_2d_nhwc_fhwc and
+// linalg.conv_2d_nhwc_fhwc_q to linalg.transpose + linalg.conv_2d_nhwc_hwcf and
+// linalg.transpose + linalg.conv_2d_nhwc_hwcf_q respectively.
+template <typename FHWCConvOp, typename HWCFConvOp>
+class ConvConverter : public OpRewritePattern<FHWCConvOp> {
+  using OpRewritePattern<FHWCConvOp>::OpRewritePattern;
+  /// Rewrites one FHWC-filter convolution into a linalg.transpose of its
+  /// filter (FHWC -> HWCF) followed by the equivalent HWCF convolution.
+  ///
+  /// Only ops with a single tensor result and a statically shaped ranked
+  /// tensor filter are handled; anything else is reported as a match failure
+  /// and left untouched.
+  LogicalResult matchAndRewrite(FHWCConvOp op,
+                                PatternRewriter &rewriter) const final {
+    // Buffer-semantics (memref) convolutions have no results; the
+    // tensor-based rewrite below does not apply to them.
+    if (op->getNumResults() != 1)
+      return rewriter.notifyMatchFailure(op, "expected a single tensor result");
+    // The filter is always the second input operand.
+    Value weight = op->getOperand(1);
+    auto weightTy = dyn_cast<RankedTensorType>(weight.getType());
+    // tensor.empty below is built from a static shape only.
+    if (!weightTy || !weightTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "expected a statically shaped ranked tensor filter");
+    // Permutation moving the leading F dimension to the end:
+    // [1, 2, ..., rank - 1, 0], i.e. [1, 2, 3, 0] for a 2D convolution. It is
+    // derived from the filter rank so the logic extends to convolutions of
+    // higher dimensions.
+    SmallVector<int64_t> weightPerm(weightTy.getRank() - 1);
+    std::iota(weightPerm.begin(), weightPerm.end(), 1);
+    weightPerm.push_back(0);
+    // Shape of the transposed filter: newShape[i] = shape[perm[i]].
+    ArrayRef<int64_t> weightShape = weightTy.getShape();
+    SmallVector<int64_t> newWeightShape;
+    newWeightShape.reserve(weightPerm.size());
+    for (int64_t dim : weightPerm)
+      newWeightShape.push_back(weightShape[dim]);
+    // linalg.transpose expects an "out" operand. Because the transpose is a
+    // plain copy, that operand must share the *filter's* element type, which
+    // differs from the result element type for e.g. i8 x i8 -> i32 quantized
+    // convolutions. The destination is zero-filled before the transpose
+    // overwrites it.
+    Type weightETy = weightTy.getElementType();
+    Location loc = op->getLoc();
+    Value emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, newWeightShape, weightETy);
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(weightETy));
+    Value zeroTensor = rewriter
+                           .create<linalg::FillOp>(loc, ValueRange{zero},
+                                                   ValueRange{emptyTensor})
+                           ->getResult(0);
+    // Transpose the filter into HWCF order.
+    Value transposedWeight =
+        rewriter
+            .create<linalg::TransposeOp>(loc, weight, zeroTensor, weightPerm)
+            ->getResult(0);
+    // Rebuild the convolution with the transposed filter. All other inputs
+    // (the image and, for the quantized variants, the zero points) are kept.
+    SmallVector<Value> newInputs{op.getInputs()};
+    newInputs[1] = transposedWeight;
+    auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+    rewriter.template replaceOpWithNewOp<HWCFConvOp>(
+        op, resultTy, newInputs, op.getOutputs(), op.getStrides(),
+        op.getDilations());
+    return success();
+  }
+// This pass converts NHWC Conv2D operations with FHWC channel orderings to NHWC
+// Conv2D operations with HWCF channel orderings.
+struct LinalgTransposeConv2D
+    : public impl::LinalgTransposeConv2DBase<LinalgTransposeConv2D> {
+  /// Greedily rewrites every FHWC-filter 2D convolution (plain and quantized)
+  /// nested under the anchor op into its HWCF counterpart.
+  void runOnOperation() override {
+    Operation *root = getOperation();
+    MLIRContext *context = root->getContext();
+    RewritePatternSet patterns(context);
+    patterns
+        .add<ConvConverter<linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcHwcfOp>>(
+            context);
+    patterns.add<
+        ConvConverter<linalg::Conv2DNhwcFhwcQOp, linalg::Conv2DNhwcHwcfQOp>>(
+        context);
+    // A failed greedy rewrite (e.g. non-convergence) is surfaced as a pass
+    // failure.
+    if (succeeded(applyPatternsAndFoldGreedily(root, std::move(patterns))))
+      return;
+    signalPassFailure();
+  }
+} // namespace
+std::unique_ptr<Pass> mlir::createLinalgTransposeConv2DPass() {
+  return std::make_unique<LinalgTransposeConv2D>();
diff --git a/mlir/test/Dialect/Linalg/transpose-conv2d.mlir b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
new file mode 100644
index 000000000000000..22019029a02743d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transpose-conv2d.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(linalg-transpose-conv2d-ops))' | FileCheck %s
+// CHECK-LABEL: @conv_2d_nhwc_fhwc
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[WEIGHTS:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+func.func @conv_2d_nhwc_fhwc(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32> {
+  // CHECK-DAG:    %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+  // CHECK:    %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+  // CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[WEIGHTS]] : tensor<8x2x2x6xf32>) outs(%[[FILL]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+  // CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  // CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
+  return %0 : tensor<1x2x2x8xf32>
+// CHECK-LABEL: @conv_2d_nhwc_fhwc_q
+// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x4x4x6xf32>, %[[WEIGHTS:.+]]: tensor<8x2x2x6xf32>, %[[INIT:.+]]: tensor<1x2x2x8xf32>, %[[A:.+]]: i32, %[[B:.+]]: i32) -> tensor<1x2x2x8xf32> {
+  func.func @conv_2d_nhwc_fhwc_q(%input: tensor<1x4x4x6xf32>, %filter: tensor<8x2x2x6xf32>, %init: tensor<1x2x2x8xf32>, %a: i32, %b: i32) -> tensor<1x2x2x8xf32> {
+  // CHECK-DAG:    %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-DAG:    %[[NEWF:.+]] = tensor.empty() : tensor<2x2x6x8xf32>
+  // CHECK:    %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[NEWF]] : tensor<2x2x6x8xf32>) -> tensor<2x2x6x8xf32>
+  // CHECK:    %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[WEIGHTS]] : tensor<8x2x2x6xf32>) outs(%[[FILL]] : tensor<2x2x6x8xf32>) permutation = [1, 2, 3, 0]
+  // CHECK:    %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%[[INPUT]], %[[TRANSPOSE]], %[[A]], %[[B]] : tensor<1x4x4x6xf32>, tensor<2x2x6x8xf32>, i32, i32) outs(%[[INIT]] : tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<2> : tensor<2xi64>}
+     ins (%input, %filter, %a, %b: tensor<1x4x4x6xf32>, tensor<8x2x2x6xf32>, i32, i32)
+    outs (%init: tensor<1x2x2x8xf32>) -> tensor<1x2x2x8xf32>
+  // CHECK:    return %[[CONV]] : tensor<1x2x2x8xf32>
+  return %0 : tensor<1x2x2x8xf32>

