[Mlir-commits] [mlir] [mlir][linalg] Fix weight dimension ordering in 2D grouped conv (PR #73855)

Felix Schneider llvmlistbot at llvm.org
Wed Nov 29 13:24:50 PST 2023


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/73855

The `conv_2d_ngchw_fgchw` op implements 2-D grouped convolution with operand dimensions ordered as given in its name. However, the current implementation orders the weights as `gfchw` instead of `fgchw`. This was already pointed out in an old Phabricator revision that never landed: https://reviews.llvm.org/D150064

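This patch
1) Adds a new op `conv_2d_ngchw_gfchw`, which keeps the weight layout
   that the old implementation actually used.
2) Fixes the dimension ordering of the existing op `conv_2d_ngchw_fgchw`
   (both the weight indexing and the declared output shape).
3) Adds tests with static dimensions, which make the expected operand
   layouts easier to understand.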

>From f99c556675aa8462789f74b174d0f0fcb91d03ff Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Tue, 28 Nov 2023 20:54:35 +0100
Subject: [PATCH] [mlir][linalg] Fix weight dimension ordering in 2D grouped
 conv

The `conv_2d_ngchw_fgchw` op implements 2-D grouped convolution
with operand dimensions ordered as given in its name. However, the
current implementation orders the weights as `gfchw` instead of
`fgchw`. This was already pointed out in an old Phabricator
revision that never landed: https://reviews.llvm.org/D150064

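This patch
1) Adds a new op `conv_2d_ngchw_gfchw`, which keeps the weight
   layout that the old implementation actually used.
2) Fixes the dimension ordering of the existing op
   `conv_2d_ngchw_fgchw` (both the weight indexing and the declared
   output shape).
3) Adds tests with static dimensions, which make the expected
   operand layouts easier to understand.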
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 101 +++++++++++++++++-
 .../linalg/opdsl/ops/core_named_ops.py        |  28 ++++-
 mlir/test/Dialect/Linalg/named-ops.mlir       |  32 ++++++
 3 files changed, 159 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 12d520cd382413a..1ff6c4086cf3576 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2911,7 +2911,106 @@ structured_op: !LinalgStructuredOpConfig
     kind: output_tensor
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s11, s1, s3, s7)>
+      (s0, s1, s11, s3, s7)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s6, s10)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d5, d3 * s4 + d6 * s6, d4 * s8 + d7 * s10)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d2, d1, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_ngchw_gfchw
+  cpp_class_name: Conv2DNgchwGfchwOp
+  doc: |-
+    Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: GFCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3 * s4 + s5 * s6, s7 * s8 + s9 * s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s1, s11, s2, s5, s9)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s11, s3, s7)>
   - !LinalgOperandDefConfig
     name: strides
     kind: index_attr
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 62b7da2ae2b5337..5b05364f6d35f3b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -780,7 +780,7 @@ def conv_2d_ngchw_fgchw(
         T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
     ),
     K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
-    O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
 ):
@@ -790,6 +790,32 @@ def conv_2d_ngchw_fgchw(
       * Input: NGCHW.
       * Kernel: FGCHW.
 
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])
+
+
+ at linalg_structured_op
+def conv_2d_ngchw_gfchw(
+    I=TensorDef(
+        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
+    ),
+    K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: GFCHW.
+
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 5ca35155854d332..29977a71dbb8644 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -409,6 +409,38 @@ func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
+func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {  // Static shapes make the NGCHW input / FGCHW kernel layout explicit.
+  // CHECK:      linalg.conv_2d_ngchw_fgchw
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,  // G=5 and FG=2 differ, so FGCHW and GFCHW weight shapes are distinguishable.
+                                         strides = dense<1> : tensor<2xi64>}  // Unit stride/dilation: OH = OW = 32 - 3 + 1 = 30.
+     ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)  // I (N,G,C,H,W) = 1,5,3,32,32; K (FG,G,C,KH,KW) = 2,5,3,3,3.
+    outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>  // O (N,G,FG,OH,OW) = 1,5,2,30,30.
+  return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv_2d_ngchw_gfchw
+func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {  // Static shapes make the NGCHW input / GFCHW kernel layout explicit.
+  // CHECK:      linalg.conv_2d_ngchw_gfchw
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  %0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,  // G=5 and FG=2 differ, so GFCHW and FGCHW weight shapes are distinguishable.
+                                         strides = dense<1> : tensor<2xi64>}  // Unit stride/dilation: OH = OW = 32 - 3 + 1 = 30.
+     ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)  // I (N,G,C,H,W) = 1,5,3,32,32; K (G,FG,C,KH,KW) = 5,2,3,3,3.
+    outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>  // O (N,G,FG,OH,OW) = 1,5,2,30,30.
+  return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
 func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
   // CHECK:      %{{.+}} = linalg.conv_3d_ndhwc_dhwcf



More information about the Mlir-commits mailing list