[Mlir-commits] [mlir] [linalg] Add quantized version of `conv_3d_ncdhw_fcdhw` (PR #113953)

Felix Schneider llvmlistbot at llvm.org
Mon Oct 28 12:25:45 PDT 2024


https://github.com/ubfx created https://github.com/llvm/llvm-project/pull/113953

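This patch adds the quantized 3-D convolution operator `conv_3d_ncdhw_fcdhw_q`, which subtracts zero-point offsets from the input and filter before multiplying. It follows the "channel-first" dimension ordering (NCDHW input, FCDHW filter) used by PyTorch and others.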

From 8ff731c832de797c8680b57c94dce67e526d3608 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Mon, 28 Oct 2024 20:22:23 +0100
Subject: [PATCH] [linalg] Add quantized version of `conv_3d_ncdhw_fcdhw`

This patch adds the quantized 3-D convolution operator `conv_3d_ncdhw_fcdhw_q`.
It follows the "channel-first" dimension ordering (NCDHW input, FCDHW
filter) used by PyTorch and others.
---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 139 ++++++++++++++++++
 .../linalg/opdsl/ops/core_named_ops.py        |  40 +++++
 mlir/test/Dialect/Linalg/roundtrip.mlir       |  15 ++
 3 files changed, 194 insertions(+)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index bf2f26de26e9ed..4e3ef937d7d48f 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -4024,6 +4024,145 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_3d_ncdhw_fcdhw_q
+  cpp_class_name: Conv3DNcdhwFcdhwQOp
+  doc: |-
+    Performs 3-D convolution with zero point offsets.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s0, s1, s2 * s3 + s4 * s5, s6 * s7 + s8 * s9, s10 * s11 + s12
+      * s13)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s14, s1, s4, s8, s12)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12,
+      s13, s14] -> (s0, s14, s2, s6, s10)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12, s13, s14] -> (s3, s7, s11)>
+    default_indices:
+    - 1
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12, s13, s14] -> (s5, s9, s13)>
+    default_indices:
+    - 1
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d8, d1 * s3 + d5 * s5, d2 * s7
+      + d6 * s9, d3 * s11 + d7 * s13)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d4, d8, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8)[s0, s1, s2, s3, s4, s5, s6,
+      s7, s8, s9, s10, s11, s12, s13, s14] -> (d0, d4, d1, d2, d3)>
+  iterator_types:
+  - parallel  # d0 = n
+  - parallel  # d1 = od
+  - parallel  # d2 = oh
+  - parallel  # d3 = ow
+  - parallel  # d4 = f
+  - reduction  # d5 = kd
+  - reduction  # d6 = kh
+  - reduction  # d7 = kw
+  - reduction  # d8 = c
+  assignments:  # O += (cast<U>(I) - cast<U>(IZp)) * (cast<U>(K) - cast<U>(KZp))
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_1d_nwc_wc
   cpp_class_name: DepthwiseConv1DNwcWcOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index b45fecd0ee1457..4c7efc8d808767 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -1126,6 +1126,46 @@ def conv_3d_ncdhw_fcdhw(
         ],
     ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
 
+@linalg_structured_op
+def conv_3d_ncdhw_fcdhw_q(
+    I=TensorDef(
+        T1,
+        S.N,
+        S.C,
+        S.OD * S.SD + S.KD * S.DD,
+        S.OH * S.SH + S.KH * S.DH,
+        S.OW * S.SW + S.KW * S.DW,
+    ),
+    K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]),
+    dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]),
+):
+    """Performs 3-D convolution with zero point offsets.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c)
+    O[D.n, D.f, D.od, D.oh, D.ow] += (
+        TypeFn.cast_signed(
+            U,
+            I[
+                D.n,
+                D.c,
+                D.od * S.SD + D.kd * S.DD,
+                D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW,
+            ],
+        ) - TypeFn.cast_signed(U, IZp)
+    ) * (
+        TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw])
+        - TypeFn.cast_signed(U, KZp)
+    )
 
 @linalg_structured_op
 def depthwise_conv_1d_nwc_wc(
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index 1b8969bd115595..6e5adf007f58d7 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -694,3 +694,18 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt:
 // CHECK-LABEL: func @conv2d_channel_first_q_promote(
 // CHECK:   %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8)
 // CHECK:         linalg.conv_2d_nchw_fchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32>
+
+// -----
+
+func.func @conv3d_channel_first_q(%img: tensor<1x27x49x48x47xi8>, %filt: tensor<28x27x3x4x5xi8>, %a: i32, %b: i32) -> tensor<1x28x47x45x43xi32> {
+  %init = arith.constant dense<0> : tensor<1x28x47x45x43xi32>
+  %1 = linalg.conv_3d_ncdhw_fcdhw_q {dilations = dense<1> : tensor<3xi64>,
+      strides = dense<1> : tensor<3xi64>}
+    ins(%img, %filt, %a, %b : tensor<1x27x49x48x47xi8>, tensor<28x27x3x4x5xi8>, i32, i32)
+    outs(%init : tensor<1x28x47x45x43xi32>) -> tensor<1x28x47x45x43xi32>
+  return %1 : tensor<1x28x47x45x43xi32>
+}
+
+// CHECK-LABEL: func @conv3d_channel_first_q(
+// CHECK:   %[[arg0:[a-zA-Z0-9]*]]: tensor<1x27x49x48x47xi8>, %[[arg1:[a-zA-Z0-9]*]]: tensor<28x27x3x4x5xi8>, %[[arg2:[a-zA-Z0-9]*]]: i32, %[[arg3:[a-zA-Z0-9]*]]: i32)
+// CHECK:         linalg.conv_3d_ncdhw_fcdhw_q {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<1x27x49x48x47xi8>, tensor<28x27x3x4x5xi8>, i32, i32) outs(%{{.*}} : tensor<1x28x47x45x43xi32>) -> tensor<1x28x47x45x43xi32>



More information about the Mlir-commits mailing list