[Mlir-commits] [mlir] 86858c6 - [mlir][tosa] Add dilation to tosa.transpose_conv2d lowering

Rob Suderman llvmlistbot at llvm.org
Tue Aug 10 14:42:20 PDT 2021


Author: Rob Suderman
Date: 2021-08-10T14:36:11-07:00
New Revision: 86858c62ba033910076320df3b2ea947026ba1ea

URL: https://github.com/llvm/llvm-project/commit/86858c62ba033910076320df3b2ea947026ba1ea
DIFF: https://github.com/llvm/llvm-project/commit/86858c62ba033910076320df3b2ea947026ba1ea.diff

LOG: [mlir][tosa] Add dilation to tosa.transpose_conv2d lowering

Dilation only requires increasing the padding on each side of the
input, and including dilation in the convolution. This implementation still
lacks support for strided convolutions.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D107680

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d0b6274ddc0a..37687337e10b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1029,56 +1029,49 @@ class TransposeConvConverter
     getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
     getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
 
-    // We have not solved for stride / dilation yet. Dilation should be
-    // straight forward but stride is more complicated. Linalg work is likely
-    // required for efficient implementation.
-    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
-      return failure();
-    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
-      return failure();
-
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return failure();
+    // If striding is all 1 we can modify padding and reverse the kernel along
+    // the x/y direction to make it a regular convolution. This is much simpler
+    // than handling striding.
+    if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
+      if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+          !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+        return failure();
+
+      int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
+      int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
+      int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
+      int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+
+      llvm::SmallVector<int64_t> convPad(4, 0);
+      convPad[0] = kernelHeight - 1 - pad[0];
+      convPad[2] = kernelWidth - 1 - pad[1];
+      convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
+      convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
+
+      auto reverse1 = rewriter.create<tosa::ReverseOp>(
+          loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+      auto reverse2 = rewriter.create<tosa::ReverseOp>(
+          loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+      Value conv2d;
+      if (op.quantization_info().hasValue()) {
+        conv2d = rewriter.create<tosa::Conv2DOp>(
+            loc, resultTy, input, reverse2, bias,
+            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+            rewriter.getI64ArrayAttr(dilation),
+            op.quantization_info().getValue());
+      } else {
+        conv2d = rewriter.create<tosa::Conv2DOp>(
+            loc, resultTy, input, reverse2, bias,
+            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+            rewriter.getI64ArrayAttr(dilation));
+      }
 
-    int64_t inputHeight = inputTy.getDimSize(1);
-    int64_t inputWidth = inputTy.getDimSize(2);
-    int64_t kernelHeight = weightTy.getDimSize(1);
-    int64_t kernelWidth = weightTy.getDimSize(2);
-    int64_t outputHeight = resultTy.getDimSize(1);
-    int64_t outputWidth = resultTy.getDimSize(2);
-
-    int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
-    int64_t requiredInputWidth = outputWidth + kernelWidth - 1;
-
-    llvm::SmallVector<int64_t> newPad(4, 0);
-    newPad[0] = kernelHeight - 1 - pad[0];
-    newPad[2] = kernelWidth - 1 - pad[1];
-
-    newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
-    newPad[3] = requiredInputWidth - newPad[2] - inputWidth;
-
-    auto reverse1 = rewriter.create<tosa::ReverseOp>(
-        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
-    auto reverse2 = rewriter.create<tosa::ReverseOp>(
-        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
-
-    Value conv2d;
-    if (op.quantization_info().hasValue()) {
-      conv2d = rewriter.create<tosa::Conv2DOp>(
-          loc, resultTy, input, reverse2, bias,
-          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
-          rewriter.getI64ArrayAttr(dilation),
-          op.quantization_info().getValue());
-    } else {
-      conv2d = rewriter.create<tosa::Conv2DOp>(
-          loc, resultTy, input, reverse2, bias,
-          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
-          rewriter.getI64ArrayAttr(dilation));
+      rewriter.replaceOp(op, conv2d);
+      return success();
     }
 
-    rewriter.replaceOp(op, conv2d);
-    return success();
+    return failure();
   }
 };
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 554fa2d4ff15..376a9103de76 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1274,6 +1274,16 @@ func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>,
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transpose_conv_dilated
+func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
+  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<4x3x3x2xf32>)
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
+  return
+}
+
 
 // -----
 


        


More information about the Mlir-commits mailing list