[Mlir-commits] [mlir] cf8ae17 - Add Int4 support for tosa::ConstOp
Eric Kunze
llvmlistbot at llvm.org
Fri Jun 30 14:05:15 PDT 2023
Author: Jerry Ge
Date: 2023-06-30T13:57:43-07:00
New Revision: cf8ae178969f031c25c3f99ab6fe39e3ef2447e6
URL: https://github.com/llvm/llvm-project/commit/cf8ae178969f031c25c3f99ab6fe39e3ef2447e6
DIFF: https://github.com/llvm/llvm-project/commit/cf8ae178969f031c25c3f99ab6fe39e3ef2447e6.diff
LOG: Add Int4 support for tosa::ConstOp
- Also added Tosa_Weight and Tosa_WeightTensorXD specifically for weights
Signed-off-by: Jerry Ge <jerry.ge at arm.com>
Reviewed By: eric-k256
Differential Revision: https://reviews.llvm.org/D153646
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
mlir/test/Dialect/Tosa/invalid.mlir
mlir/test/Dialect/Tosa/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8c8a862e87d428..9a6370591011d5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -108,7 +108,7 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
let arguments = (ins
Tosa_Tensor4D:$input,
- Tosa_Tensor4D:$weight,
+ 4DTensorOf<[Tosa_Weight]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_IntArrayAttr4:$pad,
@@ -140,7 +140,7 @@ def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
let arguments = (ins
Tosa_Tensor5D:$input,
- Tosa_Tensor5D:$weight,
+ TensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_IntArrayAttr6:$pad,
@@ -173,7 +173,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
let arguments = (ins
Tosa_Tensor4D:$input,
- Tosa_Tensor4D:$weight,
+ 4DTensorOf<[Tosa_Weight]>:$weight,
Tosa_Tensor1D:$bias,
Tosa_IntArrayAttr4:$pad,
@@ -235,7 +235,7 @@ def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [
let arguments = (ins
Tosa_Tensor2D:$input,
- Tosa_Tensor2D:$weight,
+ 2DTensorOf<[Tosa_Weight]>:$weight,
Tosa_Tensor1D:$bias,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
);
@@ -351,7 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
let arguments = (ins
Tosa_Tensor4D:$input,
- Tosa_Tensor4D:$filter,
+ 4DTensorOf<[Tosa_Weight]>:$filter,
Tosa_Tensor1D:$bias,
Tosa_IntArrayAttr4:$out_pad,
@@ -1817,7 +1817,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
);
let results = (outs
- Tosa_Tensor_Plus_F64:$output
+ TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output
);
let hasFolder = 1;
}
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 3f166cc5060a9b..e39f4662e79191 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -41,6 +41,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
def Tosa_UInt8 : UI<8>;
def Tosa_UInt16 : UI<16>;
+def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
@@ -95,10 +96,16 @@ def Tosa_Float : AnyTypeOf<[
//===----------------------------------------------------------------------===//
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float],
"number">;
+
// Add F64 type support just for tosa::CastOp and tosa::ConstOp
def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64],
"number_plus_f64">;
+// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
+// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp
+def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
+ Tosa_QuantizedInt, Tosa_Float]>;
+
//===----------------------------------------------------------------------===//
// Tensor types
//===----------------------------------------------------------------------===//
@@ -109,6 +116,7 @@ def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
// Either ranked or unranked tensor of TOSA supported element types.
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>;
+
// Must be ranked but no further constraints
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e285a9de1d66d3..5bdcc9c1e326ab 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -20,13 +20,12 @@ func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: t
// -----
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
- // expected-error@+1 {{expect a ranked tensor for weight, got <block argument> of type 'tensor<*xi8>' at index: 1}}
+ // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}}
%0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
: (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
return %0 : tensor<1x27x27x16xi8>
}
-
// -----
func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 72f020336ff052..0ad53bd4811ac7 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -44,6 +44,16 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
return %0 : tensor<1x4x4x8xf32>
}
+// -----
+// CHECK-LABEL: conv2d_q8xi4
+func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {
+ %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
+ %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %2 = "tosa.conv2d"(%arg0, %0, %1) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
+ %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i32: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
+ return %3 : tensor<1x1x1x3xi8>
+}
+
// -----
// CHECK-LABEL: depthwise_conv2d
func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
More information about the Mlir-commits
mailing list