[Mlir-commits] [mlir] 422c703 - [mlir] Updated depthwise conv to support kernel dilation
Rob Suderman
llvmlistbot at llvm.org
Tue Jun 1 13:26:34 PDT 2021
Author: Rob Suderman
Date: 2021-06-01T13:25:19-07:00
New Revision: 422c7036d5fae8e9a6a30ebe4074a0cf08da1208
URL: https://github.com/llvm/llvm-project/commit/422c7036d5fae8e9a6a30ebe4074a0cf08da1208
DIFF: https://github.com/llvm/llvm-project/commit/422c7036d5fae8e9a6a30ebe4074a0cf08da1208.diff
LOG: [mlir] Updated depthwise conv to support kernel dilation
Depthwise convolution should support kernel dilation, and the non-dilated case
should not be special-cased. Updated the op definitions to include a `dilations` attribute.
This also updates the existing tosa.depthwise_conv2d to linalg lowering to use the
new linalg behavior: the dilation is now forwarded instead of dilations greater than 1 being rejected.
Differential Revision: https://reviews.llvm.org/D103219
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
mlir/test/Dialect/Linalg/generalize-named-ops.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index a9299a41e081b..4b71a79d92532 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -155,7 +155,7 @@ ods_def<DepthwiseConvInputNHWCFilterHWCFOp>:
def depthwise_conv_2d_input_nhwc_filter_hwcf
(I: f32(N, IH, IW, CI), K: f32(KH, KW, CI, CO))
-> (O: f32(N, OH, OW, CI, CO))
- attr(strides: 2xi64)
+ attr(strides: 2xi64, dilations: 2xi64)
"""A general depth-wise 2-D convolution operation.
This operation performs depth-wise 2-D convolution over an input `I` and filter
@@ -164,7 +164,7 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
```
O(n, oh, ow, ci, co) = AddFOp<kh, kw>(
O(n, oh, ow, ci, co),
- MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
+ MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], ci),
K(kh, kw, ci, co)));
```
@@ -186,7 +186,7 @@ Linalg reshape op which collapses `CI` and `CO` into one dimension.
{
O(n, oh, ow, ci, co) = AddFOp<kh, kw>(
O(n, oh, ow, ci, co),
- MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
+ MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], ci),
K(kh, kw, ci, co)));
}
@@ -194,7 +194,7 @@ ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
def depthwise_conv_2d_input_nhwc_filter_hwc
(I: f32(N, IH, IW, C), K: f32(KH, KW, C))
-> (O: f32(N, OH, OW, C))
- attr(strides: 2xi64)
+ attr(strides: 2xi64, dilations: 2xi64)
"""A depth-wise 2-D convolution operation.
This operation performs depth-wise 2-D convolution over an input `I` and filter
@@ -203,7 +203,7 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
```
O(n, oh, ow, c) = AddFOp<kh, kw>(
O(n, oh, ow, c),
- MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
+ MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c),
K(kh, kw, c)));
```
@@ -223,7 +223,7 @@ Note: this op only supports channel multiplier == 1.
{
O(n, oh, ow, c) = AddFOp<kh, kw>(
O(n, oh, ow, c),
- MulFOp(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
+ MulFOp(I(n, oh * strides[0] + kh * dilations[0], ow * strides[1] + kw * dilations[1], c),
K(kh, kw, c)));
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 0d4092afe69d8..113c34304e761 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -956,9 +956,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
}
if (isa<tosa::DepthwiseConv2DOp>(op)) {
- if (llvm::any_of(dilation, [](int64_t d) { return d > 1; }))
- return failure();
-
ShapedType linalgConvTy =
RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
weightShape[2], weightShape[3]},
@@ -969,7 +966,7 @@ convolutionMatchAndRewriterHelper(Operation *op,
Value conv = rewriter
.create<linalg::DepthwiseConvInputNHWCFilterHWCFOp>(
loc, linalgConvTy, ValueRange{input, weight},
- ValueRange{biasReshape}, strideAttr)
+ ValueRange{biasReshape}, dilationAttr, strideAttr)
.getResult(0);
Value reshape = rewriter.create<tosa::ReshapeOp>(loc, resultTy, conv);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index ec2a67fc90e40..1d2996c95fa63 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1189,7 +1189,7 @@ func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>,
// CHECK: linalg.yield %arg3 : f32
// CHECK: } -> tensor<1x5x5x33xf32>
// CHECK: [[DBIAS:%.+]] = linalg.tensor_reshape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]]
- // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
+ // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>)
// CHECK: linalg.tensor_reshape %3 {{\[}}[0], [1], [2], [3, 4]]
%2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>)
return
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index b6231927df967..7e8d1584d38dd 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -78,7 +78,7 @@ func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C:
func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
- { strides = dense<1> : tensor<2xi64> }
+ { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x3x4x2x3xf32>)
return
@@ -103,8 +103,35 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %fil
// -----
+func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x2x3x2x3xf32>) {
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+ { dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+ outs(%output : memref<2x2x3x2x3xf32>)
+ return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5 * 2, d2 + d6 * 2, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+
+// CHECK: func @depthwise_conv_2d_input_nhwc_filter_hwcf
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x2x3x2x3xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+
+// -----
+
func @depthwise_conv_2d_input_nhwc_filter_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
- linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index c5a623aa15f42..78ed723123534 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -6,11 +6,11 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
%init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
%fill = linalg.fill(%init, %zero) : tensor<2x3x4x2x3xf32>, f32 -> tensor<2x3x4x2x3xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
- // CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
+ // CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<2x3x4x2x3xf32>)
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
- { strides = dense<1> : tensor<2xi64> }
+ { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
outs(%fill : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
return %0 : tensor<2x3x4x2x3xf32>
@@ -19,11 +19,11 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
- // CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
+ // CHECK-SAME: {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
- { strides = dense<1> : tensor<2xi64> }
+ { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
outs(%output : memref<2x3x4x2x3xf32>)
return
@@ -33,10 +33,10 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32
func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
%init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwc
- // CHECK-SAME: {strides = dense<2> : vector<2xi64>}
+ // CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
- %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
+ %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: tensor<1x113x113x96xf32>, tensor<3x3x96xf32>)
outs(%init: tensor<1x56x56x96xf32>) -> tensor<1x56x56x96xf32>
return %0: tensor<1x56x56x96xf32>
@@ -45,20 +45,58 @@ func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_memref
func @depthwise_conv_2d_input_nhwc_filter_hwc_memref(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwc
- // CHECK-SAME: {strides = dense<2> : vector<2xi64>}
+ // CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x113x113x96xf32>, memref<3x3x96xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<1x56x56x96xf32>)
- linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
}
+func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
+ %zero = constant 0.000000e+00 : f32
+ %init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32>
+ %fill = linalg.fill(%init, %zero) : tensor<2x6x7x2x3xf32>, f32 -> tensor<2x6x7x2x3xf32>
+ // CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+ // CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<2x6x7x2x3xf32>)
+ %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+ { dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%input, %filter : tensor<2x8x9x2xf32>, tensor<2x2x2x3xf32>)
+ outs(%fill : tensor<2x6x7x2x3xf32>) -> tensor<2x6x7x2x3xf32>
+ return %0 : tensor<2x6x7x2x3xf32>
+}
+
+// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref_dilated
+func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref_dilated(%input: memref<2x8x9x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x6x7x2x3xf32>) {
+ // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+ // CHECK-SAME: {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
+ // CHECK-SAME: outs(%{{.+}} : memref<2x6x7x2x3xf32>)
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+ { dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+ ins(%input, %filter : memref<2x8x9x2xf32>, memref<2x2x2x3xf32>)
+ outs(%output : memref<2x6x7x2x3xf32>)
+ return
+}
+
// -----
func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{missing indexing map required attribute 'strides'}}
- linalg.depthwise_conv_2d_input_nhwc_filter_hwc
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>}
+ ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
+ outs(%output: memref<1x56x56x96xf32>)
+ return
+}
+
+// -----
+
+func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+ // expected-error @+1 {{missing indexing map required attribute 'dilations'}}
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<1> : vector<2xi64>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@@ -68,7 +106,7 @@ func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x11
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect element type for indexing map required attribute 'strides'}}
- linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2.0> : vector<2xf32>}
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
@@ -78,7 +116,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect shape for indexing map required attribute 'strides'}}
- linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<3xi64> }
+ linalg.depthwise_conv_2d_input_nhwc_filter_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
return
More information about the Mlir-commits mailing list