[Mlir-commits] [mlir] cf8a1f6 - [mlir][tosa] Quantized Conv2DOp lowering to linalg added.

Rob Suderman llvmlistbot at llvm.org
Thu Jul 22 15:43:37 PDT 2021


Author: Rob Suderman
Date: 2021-07-22T15:42:26-07:00
New Revision: cf8a1f62083c3edbf2cd08bb16d57a70dc45722c

URL: https://github.com/llvm/llvm-project/commit/cf8a1f62083c3edbf2cd08bb16d57a70dc45722c
DIFF: https://github.com/llvm/llvm-project/commit/cf8a1f62083c3edbf2cd08bb16d57a70dc45722c.diff

LOG: [mlir][tosa] Quantized Conv2DOp lowering to linalg added.

Adds a quantized variant of the conv2D named op, together with a lowering from
TOSA to linalg and a corresponding test. The quantized and non-quantized
variants are kept as separate named ops so that non-quantized convolutions do
not incur the additional zero-point operations.

Differential Revision: https://reviews.llvm.org/D106407

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Linalg/named-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 7e07275173a05..62f90b8629875 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -709,6 +709,121 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_input_nhwc_filter_ohwi_poly_q
+  cpp_class_name: Conv2DInputNhwcFilterOhwiPolyQOp
+  doc: |-
+    Performs a 2-D quantized convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. Includes zero point
+    adjustment for quantization.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s5, s6, s3)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s7, s8, s4)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d5, d3, d4, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d5)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
   cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6509df285a890..e44b2457e9250 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -862,14 +862,24 @@ convolutionMatchAndRewriterHelper(Operation *op,
   ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
 
   Type inputETy = inputTy.getElementType();
-  Type weightETy = weightTy.getElementType();
-  Type biasETy = biasTy.getElementType();
   Type resultETy = resultTy.getElementType();
 
   auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
   auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
   auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
 
+  bool isQuantized = op->hasAttr("quantization_info");
+  IntegerAttr iZp;
+  IntegerAttr kZp;
+  if (isQuantized) {
+    auto quantizationInfo =
+        op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
+    iZp = rewriter.getI32IntegerAttr(
+        quantizationInfo.input_zp().getValue().getSExtValue());
+    kZp = rewriter.getI32IntegerAttr(
+        quantizationInfo.weight_zp().getValue().getSExtValue());
+  }
+
   if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
       !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
     return rewriter.notifyMatchFailure(op,
@@ -878,11 +888,6 @@ convolutionMatchAndRewriterHelper(Operation *op,
   auto weightShape = weightTy.getShape();
   auto resultShape = resultTy.getShape();
 
-  // TODO(suderman): Support other types.
-  if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() ||
-      !resultETy.isF32())
-    return failure();
-
   // Apply padding as necessary.
   Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
   llvm::SmallVector<int64_t> pad;
@@ -924,14 +929,23 @@ convolutionMatchAndRewriterHelper(Operation *op,
   auto dilationAttr = DenseIntElementsAttr::get(
       RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
 
-  if (isa<tosa::Conv2DOp>(op)) {
+  if (isa<tosa::Conv2DOp>(op) && !isQuantized) {
     rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyOp>(
         op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast},
         strideAttr, dilationAttr);
     return success();
   }
 
-  if (isa<tosa::DepthwiseConv2DOp>(op)) {
+  if (isa<tosa::Conv2DOp>(op) && isQuantized) {
+    auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
+    auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
+    rewriter.replaceOpWithNewOp<linalg::Conv2DInputNhwcFilterOhwiPolyQOp>(
+        op, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+        ValueRange{biasBroadcast}, strideAttr, dilationAttr);
+    return success();
+  }
+
+  if (isa<tosa::DepthwiseConv2DOp>(op) && !isQuantized) {
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index ebe9822b16630..cbb2c0e312618 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -164,6 +164,30 @@ def conv_2d_input_nhwc_filter_ohwi_poly(
            D.ow * S.SW + D.kw * S.DW,
            D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic])
 
+@linalg_structured_op
+def conv_2d_input_nhwc_filter_ohwi_poly_q(
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
+    K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs a 2-D quantized convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output. Includes zero point
+  adjustment for quantization.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic)
+  O[D.n, D.oh, D.ow, D.oc] += ((cast(
+      U, I[D.n,
+           D.oh * S.SH + D.kh * S.DH,
+           D.ow * S.SW + D.kw * S.DW,
+           D.ic]) - cast(U, IZp)) *
+           (cast(U, K[D.oc, D.kh, D.kw, D.ic]) - cast(U, KZp)))
+
+
 @linalg_structured_op
 def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
     I=TensorDef(T1, S.N, S.IH, S.IW, S.C),

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 38a70fb5f5e49..59454f2328a59 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1184,17 +1184,21 @@ func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) {
 
 // -----
 
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
+// CHECK-LABEL: @conv2d_f32
 func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28]
-  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>)
   // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>)
   %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>)  -> (tensor<1x45x40x28xf32>)
   return
 }
 
+// -----
+
+// CHECK-LABEL: @conv2d_padded_f32
 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: linalg.pad_tensor %arg0
   // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly
@@ -1204,6 +1208,24 @@ func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x
 
 // -----
 
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @conv2d_quant
+func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>, %arg2 : tensor<1024xi32>) -> () {
+  // CHECK:   %[[INIT:.+]] = linalg.init_tensor [1, 10, 10, 1024]
+  // CHECK:   %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1024xi32>) outs(%[[INIT]] : tensor<1x10x10x1024xi32>)
+  // CHECK:   ^bb0(%arg3: i32, %arg4: i32): 
+  // CHECK:     linalg.yield %arg3 : i32
+  // CHECK:   %[[C128:.+]] = constant -128 
+  // CHECK:   %[[C42:.+]] = constant 42 
+  // CHECK:   linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%1 : tensor<1x10x10x1024xi32>) 
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x10x10x1024xi32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 

diff  --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index db5d4c6c9977d..c873f66e2a652 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -1,5 +1,19 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
 
+// CHECK-LABEL: func @conv_2d_input_nhwc_filter_ohwi_poly_q_tensor
+func @conv_2d_input_nhwc_filter_ohwi_poly_q_tensor(%input: tensor<2x4x5x3xi8>, %filter: tensor<2x2x2x3xi8>) -> tensor<2x3x4x2xi32> {
+  %zero = constant 0 : i32
+  %init = linalg.init_tensor [2, 3, 4, 2] : tensor<2x3x4x2xi32>
+  %fill = linalg.fill(%zero, %init) : i32, tensor<2x3x4x2xi32> -> tensor<2x3x4x2xi32>
+  %c128 = constant -128 : i32
+  %c42 = constant 42 : i32
+  %0 = linalg.conv_2d_input_nhwc_filter_ohwi_poly_q
+     { dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
+     ins(%input, %filter, %c128, %c42 : tensor<2x4x5x3xi8>, tensor<2x2x2x3xi8>, i32, i32)
+    outs(%fill : tensor<2x3x4x2xi32>) -> tensor<2x3x4x2xi32>
+  return %0 : tensor<2x3x4x2xi32>
+}
+
 // CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor
 func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
   %zero = constant 0.000000e+00 : f32


        


More information about the Mlir-commits mailing list