[Mlir-commits] [mlir] 83c56aa - [mlir][linalg] Add depthwise_conv_2d_input_nhwc_filter_hwcf to Linalg TC ops.

Hanhan Wang llvmlistbot at llvm.org
Wed Mar 3 11:47:26 PST 2021


Author: Hanhan Wang
Date: 2021-03-03T11:47:02-08:00
New Revision: 83c56aa4ee82acb63327f96f502c4778d972b412

URL: https://github.com/llvm/llvm-project/commit/83c56aa4ee82acb63327f96f502c4778d972b412
DIFF: https://github.com/llvm/llvm-project/commit/83c56aa4ee82acb63327f96f502c4778d972b412.diff

LOG: [mlir][linalg] Add depthwise_conv_2d_input_nhwc_filter_hwcf to Linalg TC ops.

Different from the definition in TensorFlow and TOSA, the output is [N,H,W,C,M]. This can make transforms easier in Linalg because the indexing maps are plain. E.g., when determining whether the fill op has a dependency with the depthwise conv op, the current pipeline only recognizes the dependency if all of the indexing maps are projected affine maps.

Reviewed By: asaadaldien

Differential Revision: https://reviews.llvm.org/D97798

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/test/Dialect/Linalg/generalize-named-ops.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 37b972b73cf5..ee8a51821eea 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -151,6 +151,45 @@ def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N,
       std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
 }
 
+ods_def<DepthwiseConvInputNHWCFilterHWCFOp>:
+def depthwise_conv_2d_input_nhwc_filter_hwcf
+      (I: f32(N, IH, IW, CI), K: f32(KH, KW, CI, CO))
+   -> (O: f32(N, OH, OW, CI, CO))
+  attr(strides: 2xi64)
+"""A general depth-wise 2-D convolution operation.
+
+This operation performs depth-wise 2-D convolution over an input `I` and filter
+`K` and generates output `O` using the following computation:
+
+```
+  O(n, oh, ow, ci, co) = std_addf<kh, kw>(
+      O(n, oh, ow, ci, co),
+      std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
+               K(kh, kw, ci, co)));
+```
+
+where
+
+* `I` is a 4-D tensor with shape `(N, IH, IW, CI)`.
+* `K` is a 4-D tensor with shape `(KH, KW, CI, CO)`.
+* `O` is a 5-D tensor with shape `(N, OH, OW, CI, CO)`.
+* `strides` is a 2-element vector attribute for window strides along the
+  height/width dimensions.
+
+The indexing maps for these three tensors contain 7 dimensions, following the
+order of (`N`, `OH`, `OW`, `CI`, `CO`, `KH`, `KW`).
+
+Note: this op supports an arbitrary channel multiplier, which is `CO`. To map
+back to a 4-D result as in DepthwiseConvInputNHWCFilterHWCOp, you will have to
+create a Linalg reshape op which collapses `CI` and `CO` into one dimension.
+"""
+{
+  O(n, oh, ow, ci, co) = std_addf<kh, kw>(
+      O(n, oh, ow, ci, co),
+      std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
+               K(kh, kw, ci, co)));
+}
+
 ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
 def depthwise_conv_2d_input_nhwc_filter_hwc
       (I: f32(N, IH, IW, C), K: f32(KH, KW, C))
@@ -162,8 +201,10 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
 `F` and generates output `O` using the following computation:
 
 ```
-O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
-  I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)))
+O(n, oh, ow, c) = std_addf<kh, kw>(
+    O(n, oh, ow, c),
+    std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
+             K(kh, kw, c)));
 ```
 
 where

diff  --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index f28b8a801fd8..90b75abaca7d 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -76,6 +76,33 @@ func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C:
 
 // -----
 
+func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
+  linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+     { strides = dense<1> : tensor<2xi64> }
+     ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+    outs(%output : memref<2x3x4x2x3xf32>)
+  return
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
+
+// CHECK: func @depthwise_conv_2d_input_nhwc_filter_hwcf
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT:      %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT:      %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
+// CHECK-NEXT:      linalg.yield %[[ADD]] : f32
+
+// -----
+
 func @depthwise_conv_2d_input_nhwc_filter_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
   linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
     ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 66fe9afba451..499acc602b51 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,34 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor
+func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
+  %zero = constant 0.000000e+00 : f32
+  %init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
+  %fill = linalg.fill(%init, %zero) : tensor<2x3x4x2x3xf32>, f32 -> tensor<2x3x4x2x3xf32>
+  // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+  // CHECK-SAME:   {strides = dense<1> : tensor<2xi64>}
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<2x3x4x2x3xf32>)
+  %0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+     { strides = dense<1> : tensor<2xi64> }
+     ins(%input, %filter : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
+    outs(%fill : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
+  return %0 : tensor<2x3x4x2x3xf32>
+}
+
+// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
+func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
+  // CHECK:      linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+  // CHECK-SAME:   {strides = dense<1> : tensor<2xi64>}
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<2x3x4x2x3xf32>)
+  linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
+     { strides = dense<1> : tensor<2xi64> }
+     ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
+    outs(%output : memref<2x3x4x2x3xf32>)
+  return
+}
+
 // CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor
 func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
   %init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>


        


More information about the Mlir-commits mailing list