[Mlir-commits] [mlir] [mlir][tosa] Convert tosa.transpose_conv2d to linalg.generic directly (PR #79824)

Hsiangkai Wang llvmlistbot at llvm.org
Tue Jan 30 14:48:05 PST 2024


https://github.com/Hsiangkai updated https://github.com/llvm/llvm-project/pull/79824

From 57a99dd4b16111819dd3572b836594a4ea4af033 Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang <hsiangkai.wang at arm.com>
Date: Fri, 26 Jan 2024 18:01:11 +0000
Subject: [PATCH] [mlir][tosa] Convert tosa.transpose_conv2d to linalg.generic
 directly

Currently, we emulate tosa.transpose_conv2d with a sequence of
reverse, pad, reshape, and conv2d operators. This patch adds a pattern
that converts tosa.transpose_conv2d to linalg.generic directly.
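
In other words, the new pattern materializes the multiply-accumulate
below (a sketch of the intended semantics; the layouts follow TOSA's
NHWC input and OHWI weight conventions, matching the indexing maps
built in the pattern):

  output[n, ih * stride_h + out_pad_top + kh,
            iw * stride_w + out_pad_left + kw, oc]
      += input[n, ih, iw, ic] * weight[oc, kh, kw, ic]

All seven loops (n, ih, iw, oc, ic, kh, kw) become dimensions of a
single linalg.generic, so no intermediate reverse/pad/reshape tensors
are created.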
---
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  83 +++++++++++-
 .../TosaToLinalg/TosaToLinalgNamedPass.cpp    |   1 +
 .../TosaToLinalg/tosa-to-linalg-named.mlir    | 128 ++++++++++++++++++
 3 files changed, 211 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 8dc2d27bd545f..586e745c9013d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1010,6 +1010,86 @@ class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
     return success();
   }
 };
+
+class TransposeConv2DConverter
+    : public OpConversionPattern<tosa::TransposeConv2DOp> {
+public:
+  using OpConversionPattern<tosa::TransposeConv2DOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
+    if (!resultTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "tosa.transpose_conv2d requires static shapes for result");
+
+    ArrayRef<int64_t> outShape = op.getOutShapeAttr();
+    if (outShape[0] != resultTy.getDimSize(0) ||
+        outShape[1] != resultTy.getDimSize(1) ||
+        outShape[2] != resultTy.getDimSize(2) ||
+        outShape[3] != resultTy.getDimSize(3)) {
+      return rewriter.notifyMatchFailure(
+          op, "result shape is not aligned to out_shape attribute");
+    }
+
+    Location loc = op->getLoc();
+    Value input = op->getOperand(0);
+    Value weight = op->getOperand(1);
+    Value bias = op->getOperand(2);
+
+    ShapedType inputTy = cast<ShapedType>(input.getType());
+    Type inputETy = inputTy.getElementType();
+    Type resultETy = resultTy.getElementType();
+
+    if (inputETy.isUnsignedInteger())
+      return rewriter.notifyMatchFailure(
+          op, "tosa.transpose_conv2d does not support unsigned integer input");
+
+    // Broadcast the bias as the starting values for accumulation.
+    auto emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, resultTy.getShape(), resultETy);
+
+    Value broadcastBias =
+        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, emptyTensor);
+
+    auto *context = op->getContext();
+    AffineExpr n, ih, iw, oc, ic, kh, kw;
+    bindDims(context, n, ih, iw, oc, ic, kh, kw);
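+    // Loop dimensions: n = batch, (ih, iw) = input spatial position,
+    // (oc, ic) = output/input channels, (kh, kw) = kernel position.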
+
+    constexpr unsigned numDims = 7;
+    auto lhsMap = AffineMap::get(numDims, 0, {n, ih, iw, ic}, context);
+    auto rhsMap = AffineMap::get(numDims, 0, {oc, kh, kw, ic}, context);
+    // outPad ordering: top, bottom, left, right.
+    ArrayRef<int64_t> outPad = op.getOutPadAttr();
+    ArrayRef<int64_t> stride = op.getStrideAttr();
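+    // The output map scatters each partial product to output position
+    // (ih * stride + out_pad(top/left) + kh/kw), so the top/left padding
+    // offsets the write location instead of padding the input tensor.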
+    auto resultMap = AffineMap::get(numDims, 0,
+                                    {n, ih * stride[0] + outPad[0] + kh,
+                                     iw * stride[1] + outPad[2] + kw, oc},
+                                    context);
+
+    auto transposeConv2D =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultTy, ValueRange({input, weight}), broadcastBias,
+                ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap},
+                tosa::getNParallelLoopsAttrs(numDims),
+                [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                    ValueRange args) {
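+                  // Floating-point multiply-accumulate; args holds the
+                  // (input, weight, accumulator) elements per iteration.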
+                  auto mul =
+                      nestedBuilder.create<arith::MulFOp>(loc, args[0], args[1])
+                          .getResult();
+                  auto acc =
+                      nestedBuilder.create<arith::AddFOp>(loc, mul, args[2])
+                          .getResult();
+                  nestedBuilder.create<linalg::YieldOp>(loc, acc);
+                })
+            .getResult(0);
+
+    rewriter.replaceOp(op, transposeConv2D);
+
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
@@ -1031,7 +1111,8 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
       MaxPool2dConverter,
       AvgPool2dConverter,
       FullyConnectedConverter,
-      TransposeConverter
+      TransposeConverter,
+      TransposeConv2DConverter
   >(patterns->getContext());
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
index 096969391e51b..422d1e6189a21 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp
@@ -61,6 +61,7 @@ struct TosaToLinalgNamed
     target.addIllegalOp<tosa::MatMulOp>();
     target.addIllegalOp<tosa::FullyConnectedOp>();
     target.addIllegalOp<tosa::TransposeOp>();
+    target.addIllegalOp<tosa::TransposeConv2DOp>();
 
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 6616ea7cf699f..643f2e9284032 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -781,3 +781,131 @@ func.func @test_transpose_dyn_multiple_3d(%arg0: tensor<?x?x?xf32>) {
   %1 = "tosa.transpose"(%arg0, %0) : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32>
   return
 }
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d5, d6, d4)>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
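+// Note: with stride = <1, 1> and out_pad = <0, 0, 0, 0>, the output map above
+// carries no scaling factors or constant offsets.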
+
+// CHECK-LABEL: @transpose_conv2d
+func.func @transpose_conv2d(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1x1x2x1xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x3x1xf32> {
+  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x1x3x1xf32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<1xf32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x1x3x1xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        linalg.yield %[[IN]] : f32
+  // CHECK:      } -> tensor<1x1x3x1xf32>
+  // CHECK:      %[[RESULT:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg0, %arg1 : tensor<1x1x2x1xf32>, tensor<1x1x2x1xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x1x3x1xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[IN_0:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        %[[S3:.+]] = arith.mulf %[[IN]], %[[IN_0]] : f32
+  // CHECK:        %[[S4:.+]] = arith.addf %[[S3]], %[[OUT]] : f32
+  // CHECK:        linalg.yield %[[S4]] : f32
+  // CHECK:      } -> tensor<1x1x3x1xf32>
+  // CHECK:      return %[[RESULT]] : tensor<1x1x3x1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 1, 3, 1>, stride = array<i64: 1, 1>} : (tensor<1x1x2x1xf32>, tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x3x1xf32>
+  return %0 : tensor<1x1x3x1xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d5, d6, d4)>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+
+// CHECK-LABEL: @transpose_conv2d_dyn
+func.func @transpose_conv2d_dyn(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?xf32>) -> tensor<1x2x3x4xf32> {
+  // CHECK:      %[[INIT:.+]] = tensor.empty() : tensor<1x2x3x4xf32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<?xf32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x2x3x4xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        linalg.yield %[[IN]] : f32
+  // CHECK:      } -> tensor<1x2x3x4xf32>
+  // CHECK:      %[[RESULT:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x2x3x4xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[IN_0:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        %[[S3:.+]] = arith.mulf %[[IN]], %[[IN_0]] : f32
+  // CHECK:        %[[S4:.+]] = arith.addf %[[S3]], %[[OUT]] : f32
+  // CHECK:        linalg.yield %[[S4]] : f32
+  // CHECK:      } -> tensor<1x2x3x4xf32>
+  // CHECK:      return %[[RESULT]] : tensor<1x2x3x4xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?xf32>) -> tensor<1x2x3x4xf32>
+  return %0 : tensor<1x2x3x4xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d5, d6, d4)>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5 + 2, d2 + d6 + 4, d3)>
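+// Note: out_pad = <2, 3, 4, 5> supplies the top offset (+ 2) and left offset
+// (+ 4) in the output indexing map above.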
+
+// CHECK-LABEL: @transpose_conv2d_dyn_with_padding
+func.func @transpose_conv2d_dyn_with_padding(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?xf32>) -> tensor<1x2x3x4xf32> {
+  // CHECK:      %[[INIT:.+]] = tensor.empty() : tensor<1x2x3x4xf32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<?xf32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x2x3x4xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        linalg.yield %[[IN]] : f32
+  // CHECK:      } -> tensor<1x2x3x4xf32>
+  // CHECK:      %[[RESULT:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x2x3x4xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[IN_0:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        %[[S3:.+]] = arith.mulf %[[IN]], %[[IN_0]] : f32
+  // CHECK:        %[[S4:.+]] = arith.addf %[[S3]], %[[OUT]] : f32
+  // CHECK:        linalg.yield %[[S4]] : f32
+  // CHECK:      } -> tensor<1x2x3x4xf32>
+  // CHECK:      return %[[RESULT]] : tensor<1x2x3x4xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 2, 3, 4, 5>, out_shape = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?xf32>) -> tensor<1x2x3x4xf32>
+  return %0 : tensor<1x2x3x4xf32>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4)>
+// CHECK: #[[$MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d5, d6, d4)>
+// CHECK: #[[$MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 * 3 + d5 + 2, d2 * 3 + d6 + 4, d3)>
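+// Note: stride = <3, 3> scales the input spatial dims d1 and d2 by 3, while
+// out_pad contributes the constant offsets + 2 (top) and + 4 (left).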
+
+// CHECK-LABEL: @transpose_conv2d_dyn_with_padding_and_stride
+func.func @transpose_conv2d_dyn_with_padding_and_stride(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?xf32>) -> tensor<1x2x3x4xf32> {
+  // CHECK:      %[[INIT:.+]] = tensor.empty() : tensor<1x2x3x4xf32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<?xf32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x2x3x4xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        linalg.yield %[[IN]] : f32
+  // CHECK:      } -> tensor<1x2x3x4xf32>
+  // CHECK:      %[[RESULT:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP2]], #[[$MAP3]], #[[$MAP4]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x2x3x4xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[IN_0:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        %[[S3:.+]] = arith.mulf %[[IN]], %[[IN_0]] : f32
+  // CHECK:        %[[S4:.+]] = arith.addf %[[S3]], %[[OUT]] : f32
+  // CHECK:        linalg.yield %[[S4]] : f32
+  // CHECK:      } -> tensor<1x2x3x4xf32>
+  // CHECK:      return %[[RESULT]] : tensor<1x2x3x4xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 2, 3, 4, 5>, out_shape = array<i64: 1, 2, 3, 4>, stride = array<i64: 3, 3>} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?xf32>) -> tensor<1x2x3x4xf32>
+  return %0 : tensor<1x2x3x4xf32>
+}
