[Mlir-commits] [mlir] cf8ae17 - Add Int4 support for tosa::ConstOp

Eric Kunze llvmlistbot at llvm.org
Fri Jun 30 14:05:15 PDT 2023


Author: Jerry Ge
Date: 2023-06-30T13:57:43-07:00
New Revision: cf8ae178969f031c25c3f99ab6fe39e3ef2447e6

URL: https://github.com/llvm/llvm-project/commit/cf8ae178969f031c25c3f99ab6fe39e3ef2447e6
DIFF: https://github.com/llvm/llvm-project/commit/cf8ae178969f031c25c3f99ab6fe39e3ef2447e6.diff

LOG: Add Int4 support for tosa::ConstOp

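- Also added Tosa_Weight (Int4, Int8, quantized integer or float element types) and used it to constrain the weight/filter operands of tosa::Conv2DOp, tosa::Conv3DOp, tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp and tosa::FullyConnectedOp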

Signed-off-by: Jerry Ge <jerry.ge at arm.com>

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D153646

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8c8a862e87d428..9a6370591011d5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -108,7 +108,7 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    Tosa_Tensor4D:$weight,
+    4DTensorOf<[Tosa_Weight]>:$weight,
     Tosa_Tensor1D:$bias,
 
     Tosa_IntArrayAttr4:$pad,
@@ -140,7 +140,7 @@ def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
 
   let arguments = (ins
     Tosa_Tensor5D:$input,
-    Tosa_Tensor5D:$weight,
+    TensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
 
     Tosa_IntArrayAttr6:$pad,
@@ -173,7 +173,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    Tosa_Tensor4D:$weight,
+    4DTensorOf<[Tosa_Weight]>:$weight,
     Tosa_Tensor1D:$bias,
 
     Tosa_IntArrayAttr4:$pad,
@@ -235,7 +235,7 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
 
   let arguments = (ins
     Tosa_Tensor2D:$input,
-    Tosa_Tensor2D:$weight,
+    2DTensorOf<[Tosa_Weight]>:$weight,
     Tosa_Tensor1D:$bias,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
   );
@@ -351,7 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    Tosa_Tensor4D:$filter,
+    4DTensorOf<[Tosa_Weight]>:$filter,
     Tosa_Tensor1D:$bias,
 
     Tosa_IntArrayAttr4:$out_pad,
@@ -1817,7 +1817,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
   );
 
   let results = (outs
-    Tosa_Tensor_Plus_F64:$output
+    TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
   );
   let hasFolder = 1;
 }

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 3f166cc5060a9b..e39f4662e79191 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -41,6 +41,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
 def Tosa_UInt8 : UI<8>;
 def Tosa_UInt16 : UI<16>;
 
+def Tosa_Int4 : I<4>;
 def Tosa_Int8 : I<8>;
 def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
@@ -95,10 +96,16 @@ def Tosa_Float : AnyTypeOf<[
 //===----------------------------------------------------------------------===//
 def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
                                "number">;
+
 // Add F64 type support just for tosa::CastOp and tosa::ConstOp
 def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
                                "number_plus_f64">;
 
+// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
+// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
+def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
+                             Tosa_QuantizedInt, Tosa_Float]>;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
@@ -109,6 +116,7 @@ def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
 // Either ranked or unranked tensor of TOSA supported element types.
 def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
 def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
+
 // Must be ranked but no further constraints
 def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
 

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e285a9de1d66d3..5bdcc9c1e326ab 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -20,13 +20,12 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
 // -----
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  // expected-error @+1 {{expect a ranked tensor for weight, got <block argument> of type 'tensor<*xi8>' at index: 1}}
+  // expected-error @+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
   %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
            : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }
 
-
 // -----
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 72f020336ff052..0ad53bd4811ac7 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -44,6 +44,16 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
     return %0 : tensor<1x4x4x8xf32>
 }
 
+// -----
+// CHECK-LABEL: conv2d_q8xi4
+func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {
+  %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
+  %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %2 = "tosa.conv2d"(%arg0, %0, %1) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
+  %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i32: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
+  return %3 : tensor<1x1x1x3xi8>
+}
+
 // -----
 // CHECK-LABEL: depthwise_conv2d
 func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {


        


More information about the Mlir-commits mailing list