[Mlir-commits] [mlir] deebf18 - [mlir][linalg] Add pooling_nchw_max, conv_2d_nchw as yaml ops.

Tobias Gysi llvmlistbot at llvm.org
Fri Jul 23 10:37:52 PDT 2021


Author: Yi Zhang
Date: 2021-07-23T17:37:15Z
New Revision: deebf18512266e0e6917508052f6d9bbd06c7d5e

URL: https://github.com/llvm/llvm-project/commit/deebf18512266e0e6917508052f6d9bbd06c7d5e
DIFF: https://github.com/llvm/llvm-project/commit/deebf18512266e0e6917508052f6d9bbd06c7d5e.diff

LOG: [mlir][linalg] Add pooling_nchw_max, conv_2d_nchw as yaml ops.

- Add pooling_nchw_max.
- Move conv_2d_nchw to yaml ops and add strides and dilation attributes.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D106658

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Dialect/Linalg/named-ops.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 62f90b8629875..38b4619d6c178 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -905,6 +905,88 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nchw
+  cpp_class_name: Conv2DNchwOp
+  doc: |-
+    Performs 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s1, s5, s6)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s4, s7, s8)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d1, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_sum
   cpp_class_name: PoolingNhwcSumOp
@@ -1047,6 +1129,77 @@ structured_op: !LinalgStructuredOpConfig
             - !ScalarExpression
               scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nchw_max
+  cpp_class_name: PoolingNchwMaxOp
+  doc: |-
+    Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s4, s5)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s6, s7)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s10, s11)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2 * s8 + d4 * s10, d3 * s9 + d5 * s11)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d4, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: max
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_min
   cpp_class_name: PoolingNhwcMinOp

diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
index 2ae99b38c2a81..e792c110eab61 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -125,12 +125,6 @@ def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F
       O(n, h, w, f), MulFOp(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
 }
 
-ods_def<ConvNCHWOp>:
-def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
-  O(n, f, h, w) = AddFOp<kh, kw>(
-      O(n, f, h, w), MulFOp(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
-}
-
 ods_def<ConvDHWOp>:
 def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
   O(d, h, w) = AddFOp<kd, kh, kw>(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 243eb621ca46a..b225a993cb5c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1186,8 +1186,8 @@ void mlir::linalg::populateConvVectorizationPatterns(
   populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
       tiling, promotion, vectorization, tileSizes);
 
-  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
-                                               tileSizes);
+  populateVectorizationPatterns<Conv2DNchwOp, 4>(tiling, promotion,
+                                                 vectorization, tileSizes);
   populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
       tiling, promotion, vectorization, tileSizes);
 

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index cbb2c0e312618..3aa5aadc7412b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -205,6 +205,23 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
 
+@linalg_structured_op
+def conv_2d_nchw(
+    I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
+    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+  O[D.n, D.f, D.oh, D.ow] += cast(
+      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+           ]) * cast(U, K[D.f, D.c, D.kh, D.kw])
+
 
 @linalg_structured_op
 def pooling_nhwc_sum(
@@ -240,6 +257,22 @@ def pooling_nhwc_max(
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_nchw_max(
+    I=TensorDef(T1, S.N, S.C, S.H, S.W),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs max pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
+      cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+                ]))
 
 @linalg_structured_op
 def pooling_nhwc_min(

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index c873f66e2a652..138d6c219dd2c 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -30,6 +30,24 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
   return %0 : tensor<2x3x4x2x3xf32>
 }
 
+// CHECK-LABEL: func @conv_2d_nchw_tensor
+func @conv_2d_nchw_tensor(%input: tensor<2x2x4x5xf32>, %filter: tensor<4x2x3x3xf32>) -> tensor<2x4x2x3xf32> {
+    %cst = constant 0.000000e+00 : f32
+    %init = linalg.init_tensor [2, 4, 2, 3] : tensor<2x4x2x3xf32>
+    %fill = linalg.fill(%cst, %init) : f32, tensor<2x4x2x3xf32> -> tensor<2x4x2x3xf32>
+// CHECK:           %{{.+}} = linalg.conv_2d_nchw
+// CHECK-SAME:       {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME:       ins(%{{.+}}, %{{.+}} : tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
+// CHECK-SAME:       outs(%{{.+}} : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
+// CHECK:           return %{{.+}} : tensor<2x4x2x3xf32>
+// CHECK:         }
+    %0 = linalg.conv_2d_nchw
+    {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+    ins(%input, %filter: tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
+    outs(%fill : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
+    return %0 : tensor<2x4x2x3xf32>
+}
+
 // CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
 func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
   // CHECK:      linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
@@ -381,6 +399,25 @@ func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32
   return %res : tensor<1x2x2x1xf32>
 }
 
+// -----
+// CHECK-LABEL: func @pooling_nchw_max_tensor
+// CHECK:         %{{.+}} = linalg.pooling_nchw_max
+// CHECK-SAME:      dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME:      strides = dense<1> : tensor<2xi64>
+// CHECK-SAME:      ins(%{{.+}}, %{{.+}} : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
+// CHECK-SAME:      outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
+
+func @pooling_nchw_max_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> {
+  %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+  %init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32>
+  %cst = constant 0.000000e+00 : f32
+  %fill = linalg.fill(%cst, %init) : f32, tensor<1x1x2x2xf32> -> tensor<1x1x2x2xf32>
+  %res = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x1x4x4xf32>, tensor<3x3xf32>)
+    outs(%fill: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
+  return %res : tensor<1x1x2x2xf32>
+}
+
 // -----
 
 // CHECK-LABEL: func @pooling_nhwc_max

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
index 3d40083037937..5c75aa4fc6dd6 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
@@ -30,8 +30,10 @@ func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f
 }
 
 func @conv_2d_nchw(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
-  linalg.conv_2d_nchw ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
-                     outs (%arg2: memref<?x?x?x?xf32>)
+  linalg.conv_2d_nchw
+  {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+  ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+  outs (%arg2: memref<?x?x?x?xf32>)
   return
 }
 


        


More information about the Mlir-commits mailing list