[Mlir-commits] [mlir] 6bf0f6a - [mlir][tosa] Add quantized lowering for matmul and fully_connected

Rob Suderman llvmlistbot at llvm.org
Tue Jul 20 13:00:37 PDT 2021


Author: Rob Suderman
Date: 2021-07-20T12:58:02-07:00
New Revision: 6bf0f6a4f7d976a54ff59de4ef1c543ad2df9ff0

URL: https://github.com/llvm/llvm-project/commit/6bf0f6a4f7d976a54ff59de4ef1c543ad2df9ff0
DIFF: https://github.com/llvm/llvm-project/commit/6bf0f6a4f7d976a54ff59de4ef1c543ad2df9ff0.diff

LOG: [mlir][tosa] Add quantized lowering for matmul and fully_connected

Added named op variants for quantized matmul and quantized batch matmul, along
with the lowerings (and tests) from tosa's matmul and fully_connected ops. The
current version does not use the contraction op interface, because that
interface's verifiers are not compatible with the extra scalar zero-point operands.

Differential Revision: https://reviews.llvm.org/D105063

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 375e9028c987e..f41a4cb5703ab 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -62,6 +62,98 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: quantized_matmul  # same op as the opdsl `quantized_matmul` in core_named_ops.py; keep the two in sync
+  cpp_class_name: QuantizedMatmulOp
+  doc: |-
+    Performs a matrix multiplication of two 2D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. The quantized variant
+    includes zero-point adjustments for the left and right operands of the
+    matmul.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>  # A: (M, K), with s0=M, s1=K, s2=N
+  - !LinalgOperandDefConfig
+    name: B
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)>  # B: (K, N)
+  - !LinalgOperandDefConfig
+    name: AZp
+    usage: InputOperand
+    type_var: I32  # scalar zero point of A (no shape_map)
+  - !LinalgOperandDefConfig
+    name: BZp
+    usage: InputOperand
+    type_var: I32  # scalar zero point of B (no shape_map)
+  - !LinalgOperandDefConfig
+    name: C
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>  # C: (M, N)
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>  # A[m, k], with d0=m, d1=n, d2=k
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>  # B[k, n]
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> ()>  # AZp (scalar)
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> ()>  # BZp (scalar)
+    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>  # C[m, n]
+  iterator_types:
+  - parallel
+  - parallel
+  - reduction  # d2 (k) is summed over
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add  # C += (cast<U>(A) - cast<U>(AZp)) * (cast<U>(B) - cast<U>(BZp))
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: A
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: AZp
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: B
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: BZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: mmt4d
   cpp_class_name: Mmt4DOp
@@ -198,6 +290,99 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: quantized_batch_matmul  # same op as the opdsl `quantized_batch_matmul` in core_named_ops.py; keep the two in sync
+  cpp_class_name: QuantizedBatchMatmulOp
+  doc: |-
+    Performs a batched matrix multiplication of two 3D inputs.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. The quantized variant
+    includes zero-point adjustments for the left and right operands of the
+    matmul.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: A
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>  # A: (Batch, M, K), with s0=Batch, s1=M, s2=K, s3=N
+  - !LinalgOperandDefConfig
+    name: B
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)>  # B: (Batch, K, N)
+  - !LinalgOperandDefConfig
+    name: AZp
+    usage: InputOperand
+    type_var: I32  # scalar zero point of A (no shape_map)
+  - !LinalgOperandDefConfig
+    name: BZp
+    usage: InputOperand
+    type_var: I32  # scalar zero point of B (no shape_map)
+  - !LinalgOperandDefConfig
+    name: C
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>  # C: (Batch, M, N)
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>  # A[b, m, k], with d0=b, d1=m, d2=n, d3=k
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>  # B[b, k, n]
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> ()>  # AZp (scalar)
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> ()>  # BZp (scalar)
+    - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>  # C[b, m, n]
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction  # d3 (k) is summed over
+  assignments:
+  - !ScalarAssign
+    arg: C
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add  # C += (cast<U>(A) - cast<U>(AZp)) * (cast<U>(B) - cast<U>(BZp))
+        operands:
+        - !ScalarExpression
+          scalar_arg: C
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: A
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: AZp
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: B
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: BZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matvec
   cpp_class_name: MatvecOp

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8ed95d777ba26..f99f76d4fae74 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1019,9 +1019,24 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
         loc, outputTy.getShape(), outputTy.getElementType());
     Value zeroTensor =
         rewriter.create<linalg::FillOp>(loc, zero, initTensor).getResult(0);
-    rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
-        op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
-        ValueRange{zeroTensor});
+    if (!op.quantization_info()) {
+      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
+          op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()},
+          ValueRange{zeroTensor});
+      return success();
+    }
+
+    auto quantizationInfo = op.quantization_info().getValue();
+    auto aZp = rewriter.create<ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(
+                 quantizationInfo.a_zp().getValue().getSExtValue()));
+    auto bZp = rewriter.create<ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(
+                 quantizationInfo.b_zp().getValue().getSExtValue()));
+    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
+        op, TypeRange{op.getType()},
+        ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor);
+
     return success();
   }
 };
@@ -1040,13 +1055,8 @@ class FullyConnectedConverter
     auto bias = op.bias();
 
     auto weightTy = weight.getType().cast<ShapedType>();
-    auto biasTy = bias.getType().cast<ShapedType>();
-
     auto weightShape = weightTy.getShape();
 
-    if (op.quantization_info())
-      return failure();
-
     // Creating maps for the output of MatMul and the bias
     SmallVector<AffineMap, 4> indexingMaps;
 
@@ -1081,14 +1091,29 @@ class FullyConnectedConverter
 
     SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
     Type newWeightTy =
-        RankedTensorType::get(newWeightShape, biasTy.getElementType());
+        RankedTensorType::get(newWeightShape, weightTy.getElementType());
 
     Value transposedWeight = rewriter.create<tosa::TransposeOp>(
         loc, newWeightTy, weight, permutationValue);
 
-    rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
-        op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
-        linalgOp);
+    if (!op.quantization_info()) {
+      rewriter.replaceOpWithNewOp<linalg::MatmulOp>(
+          op, TypeRange{op.getType()}, ValueRange{input, transposedWeight},
+          linalgOp);
+      return success();
+    }
+
+    auto quantizationInfo = op.quantization_info().getValue();
+    auto inputZp = rewriter.create<ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(
+                 quantizationInfo.input_zp().getValue().getSExtValue()));
+    auto outputZp = rewriter.create<ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(
+                 quantizationInfo.weight_zp().getValue().getSExtValue()));
+    rewriter.replaceOpWithNewOp<linalg::QuantizedMatmulOp>(
+        op, TypeRange{op.getType()},
+        ValueRange{input, transposedWeight, inputZp, outputZp}, linalgOp);
+
     return success();
   }
 };

diff  --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index f1a2fd3ef70a9..3ea171f78ef48 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -20,6 +20,22 @@ def matmul(
   implements(ContractionOpInterface)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
+@linalg_structured_op
+def quantized_matmul(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    AZp=ScalarDef(I32),
+    BZp=ScalarDef(I32),
+    C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs a matrix multiplication of two 2D inputs.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output. The quantized variant
+  includes zero-point adjustments for the left and right operands of the
+  matmul.
+  """
+  domain(D.m, D.n, D.k)  # k is reduced; each zero point is subtracted after casting to U
+  C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.k, D.n]) - cast(U, BZp))
 
 @linalg_structured_op
 def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
@@ -40,7 +56,6 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
   implements(ContractionOpInterface)
   accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
 
-
 @linalg_structured_op
 def batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),
@@ -55,6 +70,23 @@ def batch_matmul(
   implements(ContractionOpInterface)
   C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n])
 
+@linalg_structured_op
+def quantized_batch_matmul(
+    A=TensorDef(T1, Batch, S.M, S.K),
+    B=TensorDef(T2, Batch, S.K, S.N),
+    AZp=ScalarDef(I32),
+    BZp=ScalarDef(I32),
+    C=TensorDef(U, Batch, S.M, S.N, output=True)):
+  """Performs a batched matrix multiplication of two 3D inputs.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output. The quantized variant
+  includes zero-point adjustments for the left and right operands of the
+  matmul.
+  """
+  domain(D.b, D.m, D.n, D.k)  # k is reduced; each zero point is subtracted after casting to U
+  C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.b, D.k, D.n]) - cast(U, BZp))
+
 
 @linalg_structured_op
 def matvec(

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index af04fa5e9ed27..1766e8ec144cb 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -855,6 +855,21 @@ func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>, %arg2: tensor<1
 
 // -----
 
+
+// CHECK-LABEL: @matmul_quantized
+func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) {  // a_zp/b_zp lower to i32 constants fed to quantized_batch_matmul
+  // CHECK: [[C0:%.+]] = constant 0 : i32
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6]
+  // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : i32, tensor<1x5x6xi32> -> tensor<1x5x6xi32>
+  // CHECK: [[ONE:%.+]] = constant 1 : i32
+  // CHECK: [[TWO:%.+]] = constant 2 : i32
+  // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
+  %0 = "tosa.matmul"(%arg0, %arg1) {quantization_info = {a_zp = 1 : i32, b_zp = 2 : i32}} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>)
+  return %0 : tensor<1x5x6xi32>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
@@ -876,6 +891,29 @@ func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: ten
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @quantized_fully_connected
+func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {  // bias broadcast, weight transpose, then quantized_matmul with i32 zero points
+  // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs([[INITB]] : tensor<5x6xi32>) {
+  // CHECK: ^bb0([[IN:%.+]]: i32, [[UNUSED:%.+]]: i32):
+  // CHECK:   linalg.yield [[IN]] : i32
+  // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6]
+  // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xi8>) outs([[INITT]]
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
+  // CHECK:   linalg.yield [[IN]] : i8
+  // CHECK: [[ONE:%.+]] = constant 1 : i32
+  // CHECK: [[TWO:%.+]] = constant 2 : i32
+  // CHECK: linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[GENERIC]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) {quantization_info = {input_zp = 1:i32, weight_zp = 2:i32}} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>)  -> (tensor<5x6xi32>)
+  return %0 : tensor<5x6xi32>
+}
+
+// -----
+
 func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) {
   %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
   // TODO: Output contains multiple "constant 1 : index".


        


More information about the Mlir-commits mailing list