[Mlir-commits] [mlir] a935a0b - Adding a new variant of DepthwiseConv2D

George Petterson llvmlistbot at llvm.org
Thu Jul 21 11:37:30 PDT 2022


Author: George Petterson
Date: 2022-07-21T14:36:57-04:00
New Revision: a935a0bf5070f8c06c36eee03893226dd734ba0a

URL: https://github.com/llvm/llvm-project/commit/a935a0bf5070f8c06c36eee03893226dd734ba0a
DIFF: https://github.com/llvm/llvm-project/commit/a935a0bf5070f8c06c36eee03893226dd734ba0a.diff

LOG: Adding a new variant of DepthwiseConv2D

This is the same as the existing multiplier-1 variant of DepthwiseConv2D, but with operands in PyTorch dimension order (NCHW input/output, CHW filter).

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D128575

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 85b0d641ba915..33583d5b1f01c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2104,6 +2104,98 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: depthwise_conv_2d_nchw_chw
+  cpp_class_name: DepthwiseConv2DNchwChwOp
+  doc: |-
+    Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Multiplier is set to 1
+    which is a special case for most depthwise convolutions.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1 *
+      s2 + s3 * s4, s5 * s6 + s7 * s8)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s9, s3, s7)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1, s5)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
+      s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
+      s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d0, d3, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d3, d4, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
+      -> (d0, d3, d1, d2)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_nhwc_hwc_q
   cpp_class_name: DepthwiseConv2DNhwcHwcQOp

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 7dd3f9495a796..b22e6c3b1e415 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -482,6 +482,32 @@ def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
            D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic])
 
 
+@linalg_structured_op
+def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH,
+                                           S.OW * S.SW + S.KW * S.DW),
+                               K=TensorDef(T2, S.IC, S.KH, S.KW),
+                               O=TensorDef(U,
+                                           S.N,
+                                           S.IC,
+                                           S.OH,
+                                           S.OW,
+                                           output=True),
+                               strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+                               dilations=IndexAttrDef(S.DH,
+                                                      S.DW,
+                                                      default=[1, 1])):
+  """Performs depth-wise 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output. Multiplier is set to 1
+  which is a special case for most depthwise convolutions.
+  """
+  implements(ConvolutionOpInterface)
+  domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw)
+  O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed(
+      U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw])
+
+
 @linalg_structured_op
 def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH,
                                              S.OW * S.SW + S.KW * S.DW, S.IC),

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index e1ade46716533..9c60df33cedd3 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -95,6 +95,31 @@ func.func @depthwise_conv_2d_nhwc_hwc_memref(%input: memref<1x113x113x96xf32>, %
   return
 }
 
+// CHECK-LABEL: func @depthwise_conv_2d_nchw_chw_tensor
+func.func @depthwise_conv_2d_nchw_chw_tensor(%input: tensor<1x96x113x113xf32>, %filter: tensor<96x3x3xf32>) -> tensor<1x96x56x56xf32> {
+  %init = linalg.init_tensor [1, 96, 56, 56] : tensor<1x96x56x56xf32>
+  // CHECK:      %{{.+}} = linalg.depthwise_conv_2d_nchw_chw
+  // CHECK-SAME:   {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x96x113x113xf32>, tensor<96x3x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
+  %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x96x113x113xf32>, tensor<96x3x3xf32>)
+         outs(%init: tensor<1x96x56x56xf32>) -> tensor<1x96x56x56xf32>
+  return %0: tensor<1x96x56x56xf32>
+}
+
+// CHECK-LABEL: func @depthwise_conv_2d_nchw_chw_memref
+func.func @depthwise_conv_2d_nchw_chw_memref(%input: memref<1x96x113x113xf32>, %filter: memref<96x3x3xf32>, %output: memref<1x96x56x56xf32>) {
+  // CHECK:      linalg.depthwise_conv_2d_nchw_chw
+  // CHECK-SAME:   {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<1x96x113x113xf32>, memref<96x3x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<1x96x56x56xf32>)
+  linalg.depthwise_conv_2d_nchw_chw {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+    ins(%input, %filter: memref<1x96x113x113xf32>, memref<96x3x3xf32>)
+    outs(%output: memref<1x96x56x56xf32>)
+  return
+}
+
 func.func @depthwise_conv_2d_nhwc_hwcm_tensor_dilated(%input: tensor<2x8x9x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x6x7x2x3xf32> {
   %zero = arith.constant 0.000000e+00 : f32
   %init = linalg.init_tensor [2, 6, 7, 2, 3] : tensor<2x6x7x2x3xf32>
@@ -234,6 +259,22 @@ func.func @conv_2d_nhwc_hwcf(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_ngchw_fgchw
+func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_2d_ngchw_fgchw
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+                                              strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_2d_nhwc_fhwc
 func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
@@ -282,6 +323,22 @@ func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_ngchw_fgchw
+func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  // CHECK:      linalg.conv_2d_ngchw_fgchw
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : memref<?x?x?x?x?xf32>)
+  linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
+    outs (%output: memref<?x?x?x?x?xf32>)
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
 func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
   // CHECK:      %{{.+}} = linalg.conv_3d_ndhwc_dhwcf


        


More information about the Mlir-commits mailing list