[Mlir-commits] [mlir] 86858c6 - [mlir][tosa] Add dilation to tosa.transpose_conv2d lowering
Rob Suderman
llvmlistbot at llvm.org
Tue Aug 10 14:42:20 PDT 2021
Author: Rob Suderman
Date: 2021-08-10T14:36:11-07:00
New Revision: 86858c62ba033910076320df3b2ea947026ba1ea
URL: https://github.com/llvm/llvm-project/commit/86858c62ba033910076320df3b2ea947026ba1ea
DIFF: https://github.com/llvm/llvm-project/commit/86858c62ba033910076320df3b2ea947026ba1ea.diff
LOG: [mlir][tosa] Add dilation to tosa.transpose_conv2d lowering
Dilation only requires increasing the padding on each side of the
input, and including dilation in the convolution. This implementation still
lacks support for strided convolutions.
Reviewed By: NatashaKnk
Differential Revision: https://reviews.llvm.org/D107680
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d0b6274ddc0a..37687337e10b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1029,56 +1029,49 @@ class TransposeConvConverter
getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
- // We have not solved for stride / dilation yet. Dilation should be
- // straight forward but stride is more complicated. Linalg work is likely
- // required for efficient implementation.
- if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
- return failure();
- if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
- return failure();
-
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
- return failure();
+ // If striding is all 1 we can modify padding and reverse the kernel along
+ // the x/y direction to make it a regular convolution. This is much simpler
+ // then handling striding....
+ if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
+ if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+ !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ return failure();
+
+ int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
+ int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
+ int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
+ int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+
+ llvm::SmallVector<int64_t> convPad(4, 0);
+ convPad[0] = kernelHeight - 1 - pad[0];
+ convPad[2] = kernelWidth - 1 - pad[1];
+ convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
+ convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
+
+ auto reverse1 = rewriter.create<tosa::ReverseOp>(
+ loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+ auto reverse2 = rewriter.create<tosa::ReverseOp>(
+ loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+ Value conv2d;
+ if (op.quantization_info().hasValue()) {
+ conv2d = rewriter.create<tosa::Conv2DOp>(
+ loc, resultTy, input, reverse2, bias,
+ rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+ rewriter.getI64ArrayAttr(dilation),
+ op.quantization_info().getValue());
+ } else {
+ conv2d = rewriter.create<tosa::Conv2DOp>(
+ loc, resultTy, input, reverse2, bias,
+ rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+ rewriter.getI64ArrayAttr(dilation));
+ }
- int64_t inputHeight = inputTy.getDimSize(1);
- int64_t inputWidth = inputTy.getDimSize(2);
- int64_t kernelHeight = weightTy.getDimSize(1);
- int64_t kernelWidth = weightTy.getDimSize(2);
- int64_t outputHeight = resultTy.getDimSize(1);
- int64_t outputWidth = resultTy.getDimSize(2);
-
- int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
- int64_t requiredInputWidth = outputWidth + kernelWidth - 1;
-
- llvm::SmallVector<int64_t> newPad(4, 0);
- newPad[0] = kernelHeight - 1 - pad[0];
- newPad[2] = kernelWidth - 1 - pad[1];
-
- newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
- newPad[3] = requiredInputWidth - newPad[2] - inputWidth;
-
- auto reverse1 = rewriter.create<tosa::ReverseOp>(
- loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
- auto reverse2 = rewriter.create<tosa::ReverseOp>(
- loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
-
- Value conv2d;
- if (op.quantization_info().hasValue()) {
- conv2d = rewriter.create<tosa::Conv2DOp>(
- loc, resultTy, input, reverse2, bias,
- rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
- rewriter.getI64ArrayAttr(dilation),
- op.quantization_info().getValue());
- } else {
- conv2d = rewriter.create<tosa::Conv2DOp>(
- loc, resultTy, input, reverse2, bias,
- rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
- rewriter.getI64ArrayAttr(dilation));
+ rewriter.replaceOp(op, conv2d);
+ return success();
}
- rewriter.replaceOp(op, conv2d);
- return success();
+ return failure();
}
};
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 554fa2d4ff15..376a9103de76 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1274,6 +1274,16 @@ func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>,
return
}
+// -----
+
+// CHECK-LABEL: @transpose_conv_dilated
+func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
+ // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
+ // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<4x3x3x2xf32>)
+ %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
+ return
+}
+
// -----
More information about the Mlir-commits
mailing list