[Mlir-commits] [mlir] [mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs (PR #129720)
Peng Sun
llvmlistbot at llvm.org
Thu Mar 6 08:26:46 PST 2025
https://github.com/psunn updated https://github.com/llvm/llvm-project/pull/129720
>From 80119ffd61fcc18c85380318c53638956f8b1bcf Mon Sep 17 00:00:00 2001
From: Peng Sun <peng.sun at arm.com>
Date: Thu, 29 Feb 2024 18:31:08 +0000
Subject: [PATCH] [TOSA] Convert RESCALE op multiplier and shift from
attributes to inputs
This patch updates the TOSA RescaleOp by converting its multiplier and
shift parameters from attributes to explicit inputs, aligning the op
with the TOSA V1.0 specification.
Additionally, this commit adds RescaleOp-specific implementations of
inferReturnTypeComponents and verify functions.
Co-authored-by: Tai Ly <tai.ly at arm.com>
Signed-off-by: Peng Sun <peng.sun at arm.com>
Change-Id: I9e21bf757e736dabea5a2e77398e1b8a268b8ee9
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 10 +-
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 20 +-
.../mlir/Dialect/Tosa/Utils/QuantUtils.h | 20 ++
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 22 ++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 134 ++++++++++++-
.../TosaToLinalg/tosa-to-linalg-invalid.mlir | 4 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 34 +++-
mlir/test/Dialect/Tosa/availability.mlir | 2 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 4 +-
mlir/test/Dialect/Tosa/invalid.mlir | 181 ++++++++++++++++++
mlir/test/Dialect/Tosa/level_check.mlir | 4 +-
mlir/test/Dialect/Tosa/ops.mlir | 8 +-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 10 +-
mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp | 20 +-
14 files changed, 437 insertions(+), 36 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 61178a0110aa3..8df30812c41db 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2279,9 +2279,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
//===----------------------------------------------------------------------===//
// Operator: rescale
//===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
- DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
- ["inferReturnTypeComponents"]>]> {
+def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
let summary = "Tosa rescale operator";
let description = [{
@@ -2307,10 +2305,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
let arguments = (ins
Tosa_Tensor:$input,
+ Tosa_1DInt16Or32Tensor:$multiplier,
+ Tosa_1DInt8Tensor:$shift,
I32Attr:$input_zp,
I32Attr:$output_zp,
- DenseI32ArrayAttr:$multiplier,
- DenseI8ArrayAttr:$shift,
BoolAttr:$scale32,
BoolAttr:$double_round,
BoolAttr:$per_channel,
@@ -2327,6 +2325,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
Extension<[Tosa_EXT_INT16]>,
];
+ let hasVerifier = 1;
+
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
}
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 7a8357ecfa430..b1ba7d31018bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
def Tosa_Int4 : I<4>;
def Tosa_Int8 : I<8>;
+def Tosa_Int16 : I<16>;
def Tosa_Int32 : I<32>;
def Tosa_Int64 : I<64>;
@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
AnySignlessInteger]>;
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
- Tosa_Int64]>;
+ Tosa_Int64]>;
+
+def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
+ Tosa_Int32]>;
//===----------------------------------------------------------------------===//
// Quantized Integer Types.
@@ -74,6 +78,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
Tosa_QuantizedType<"int16", [16, 0], 1>,
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def Tosa_Float : AnyTypeOf<[F32,
+ F16,
+ BF16]>;
+
+def Tosa_F8 : AnyTypeOf<[F8E4M3FN,
+ F8E5M2]>;
+
//===----------------------------------------------------------------------===//
// Multi-category types.
//===----------------------------------------------------------------------===//
@@ -163,6 +177,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
+// 1D tensor of specific types
+def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
+def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
+
// Ranked tensors up to given rank.
def Tosa_Tensor1Dto4D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index cb0df7aff35f7..bdd8713037eea 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -31,6 +31,26 @@ namespace tosa {
bool computeMultiplierAndShift(double scale, int32_t &multiplier,
int32_t &shift, int32_t scaleWidth);
+// Creates a tosa::ConstOp holding `vec` as a rank-1 tensor of shape {vec.size()} and returns its result; IntType must be int8_t, int16_t or int32_t.
+template <typename IntType>
+Value getConstTensorInt(OpBuilder &builder, Location loc,
+ ArrayRef<IntType> vec) {
+ static_assert(
+ std::is_same<IntType, int8_t>::value ||
+ std::is_same<IntType, int16_t>::value ||
+ std::is_same<IntType, int32_t>::value,
+ "getConstTensorInt only supports int8_t, int16_t, and int32_t types.");
+
+ int64_t count = vec.size();
+ assert(count > 0 && "Vector must not be empty"); // an empty `vec` is only rejected in assert-enabled builds
+ auto element_type = builder.getIntegerType(sizeof(IntType) * 8); // element bit width follows sizeof(IntType): 8, 16 or 32
+ mlir::RankedTensorType const_type =
+ RankedTensorType::get({count}, element_type);
+ mlir::DenseElementsAttr const_attr = DenseElementsAttr::get(const_type, vec);
+ auto const_op = builder.create<tosa::ConstOp>(loc, const_type, const_attr);
+ return const_op.getResult();
+}
+
//// Builds ConvOpQuantizationAttr from input and weight.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
Value input, Value weight);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..b7319297e0c25 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -138,7 +138,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MulOp
if (isa<tosa::MulOp>(op)) {
auto shift_val = cast<tosa::MulOp>(op).getShift();
- ElementsAttr shift_elem;
+ DenseElementsAttr shift_elem;
if (!shift_val.getImpl() ||
!matchPattern(shift_val, m_Constant(&shift_elem))) {
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
// The shift and multiplier values.
- SmallVector<int32_t> multiplierValues(op.getMultiplier());
- SmallVector<int8_t> shiftValues(op.getShift());
+ DenseElementsAttr shiftElems;
+ if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ return rewriter.notifyMatchFailure(
+ op, "tosa.rescale requires constant shift input values");
+
+ DenseElementsAttr multiplierElems;
+ if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ return rewriter.notifyMatchFailure(
+ op, "tosa.rescale requires constant multiplier input values");
+
+ llvm::SmallVector<int8_t> shiftValues =
+ llvm::to_vector(shiftElems.getValues<int8_t>());
+ // explicit cast is required here
+ llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ea4414fc1890e..15c1b3e8a53e6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2225,7 +2225,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
-NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SinOp)
@@ -2675,6 +2674,139 @@ LogicalResult TransposeConv2DOp::verify() {
return success();
}
+LogicalResult RescaleOp::verify() { // Checks element types, zero points, and multiplier/shift element type and length; emits an op error and fails at the first violation.
+ auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
+ if (!inputType || !inputType.hasRank()) { // the rank is needed further down for the per_channel lookup
+ emitOpError("expect shaped tensor with rank for input, got ")
+ << getInput().getType();
+ return failure();
+ }
+
+ auto inputElementType =
+ getStorageElementTypeOrSelf(inputType.getElementType()); // NOTE(review): presumably unwraps quantized types to their storage integer type — confirm
+ if (!mlir::isa<IntegerType>(inputElementType)) {
+ emitOpError("expect input to have integer element type, got ")
+ << inputElementType;
+ return failure();
+ }
+
+ auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
+ if (!outputType) { // unlike the input, the output is not required to have a rank here
+ emitOpError("expect shaped tensor for output, got ")
+ << getOutput().getType();
+ return failure();
+ }
+
+ auto outputElementType =
+ getStorageElementTypeOrSelf(outputType.getElementType());
+ if (!mlir::isa<IntegerType>(outputElementType)) {
+ emitOpError("expect output to have integer element type, got ")
+ << outputElementType;
+ return failure();
+ }
+
+ auto input_zp = getInputZpAttr().getInt();
+ if (input_zp != 0) {
+ // a non-zero input_zp is accepted only for 8-bit input, or for 16-bit input with input_unsigned set
+ if (!inputElementType.isInteger(8) &&
+ !(inputElementType.isInteger(16) && getInputUnsigned())) {
+ emitOpError("expect input_zp of 0, got ") << input_zp;
+ return failure();
+ }
+ // for unsigned 16-bit input the only non-zero input_zp allowed is 32768
+ if (inputElementType.isInteger(16) && getInputUnsigned() &&
+ input_zp != 32768) {
+ emitOpError(
+ "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
+ << input_zp;
+ return failure();
+ }
+ }
+
+ auto output_zp = getOutputZpAttr().getInt();
+ if (output_zp != 0) {
+ // a non-zero output_zp is accepted only for 8-bit output, or for 16-bit output with output_unsigned set
+ if (!outputElementType.isInteger(8) &&
+ !(outputElementType.isInteger(16) && getOutputUnsigned())) {
+ emitOpError("expect output_zp of 0, got ") << output_zp;
+ return failure();
+ }
+ // for unsigned 16-bit output the only non-zero output_zp allowed is 32768
+ if (outputElementType.isInteger(16) && getOutputUnsigned() &&
+ output_zp != 32768) {
+ emitOpError(
+ "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
+ << output_zp;
+ return failure();
+ }
+ }
+
+ auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
+ if (!multiplierType) {
+ emitOpError("expect shaped tensor for multiplier, got ")
+ << getMultiplier().getType();
+ return failure();
+ }
+
+ auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
+ if (!shiftType) {
+ emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
+ return failure();
+ }
+
+ // scale32 = true requires an i32 multiplier element type
+ if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
+ emitOpError("expect i32 element type for multiplier for scale32=true, got ")
+ << multiplierType.getElementType();
+ return failure();
+ }
+
+ // scale32 = false requires an i16 multiplier element type
+ if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
+ emitOpError(
+ "expect i16 element type for multiplier for scale32=false, got ")
+ << multiplierType.getElementType();
+ return failure();
+ }
+
+ // multiplier and shift must each have shape {numChannels},
+ // where numChannels is 1 when per_channel = false,
+ // otherwise numChannels is the size of the input's last dimension
+ int64_t numChannels = 1;
+ if (getPerChannel()) {
+ numChannels = inputType.getDimSize(inputType.getRank() - 1); // NOTE(review): no rank >= 1 guard is visible; a rank-0 input would ask for dimension -1 here — confirm
+ }
+
+ ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
+ // multiplier input has rank 1 by dialect definition
+ if (multiplierShape[0] != ShapedType::kDynamic && // a dynamic multiplier length is accepted without comparison
+ multiplierShape[0] != numChannels) { // NOTE(review): if the input's last dim is dynamic, numChannels is presumably kDynamic and a static length would be rejected — confirm
+ emitOpError("expect shape of { ")
+ << numChannels << " } for multiplier input, got { "
+ << multiplierShape[0] << " }";
+ return failure();
+ }
+
+ ArrayRef<int64_t> shiftShape = shiftType.getShape();
+ // shift input has rank 1 by dialect definition
+ if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) { // same length rule as for the multiplier
+ emitOpError("expect shape of { ")
+ << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
+ return failure();
+ }
+
+ return success();
+}
+
+LogicalResult RescaleOp::inferReturnTypeComponents( // Infers the single result from the input operand alone; always succeeds.
+ MLIRContext *context, ::std::optional<Location> location,
+ RescaleOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ ShapeAdaptor inputShape(adaptor.getInput().getType()); // the multiplier and shift operands are not consulted
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape)); // one entry, built from the input's shape adaptor
+ return success();
+}
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 08b147ac1dc1b..070c0f66d622b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
// CHECK-LABEL: @rescale_unsupported_type
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..16a49376e58f5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,7 +1149,9 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
- %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1178,7 +1180,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK: linalg.yield [[TRUNC]]
- %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1191,17 +1195,19 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK-LABEL: @rescale_i8_dyn_batch
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
return
}
@@ -1219,7 +1225,9 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
- %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 38>, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
+ %multiplier = "tosa.const"() {value = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
return
}
@@ -1247,7 +1255,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK: linalg.yield [[TRUNC]]
- %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
return
}
@@ -1277,7 +1287,9 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
- %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>) -> tensor<3xi8>
+ %multiplier = "tosa.const"() {value = dense<[42, 43, 44]> : tensor<3xi16> } : () -> tensor<3xi16>
+ %shift = "tosa.const"() {value = dense<[14, 15, 64]> : tensor<3xi8> } : () -> tensor<3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
// CHECK: return [[GENERIC]]
return %0 : tensor<3xi8>
@@ -1290,7 +1302,9 @@ func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
// CHECK: linalg.generic
// CHECK: tosa.apply_scale
// CHECK-SAME: {double_round = true}
- %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 33>, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -1299,7 +1313,9 @@ func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>)
// CHECK: linalg.generic
// CHECK: tosa.apply_scale
// CHECK-SAME: {double_round = false}
- %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 28fe8c5003d3c..5fa25f9e7d4c8 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -614,7 +614,7 @@ func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
// CHECK: profiles: [ [pro_int] ]
// CHECK: extensions: [ [int16] ]
- %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 704ea5a0b9c3a..94978b233572c 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -903,7 +903,9 @@ func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
%0 = "tosa.const"() {value = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%cst0 = "tosa.const_shape"() {value = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%1 = tosa.reshape %0, %cst0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
- %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1xi32>
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %2 = tosa.rescale %1, %multiplier, %shift {double_round = true, input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
return %2 : tensor<1x1x1x1xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9a5f4612db42d..732fa957bb32e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1431,3 +1431,184 @@ func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<
%0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
return %0 : tensor<1x2x3xi32>
}
+
+// -----
+
+func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi32> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect input to have integer element type, got 'f32'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect output to have integer element type, got 'f32'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi48> } : () -> tensor<1xi48>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+ // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1xi48>'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+ // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1xi16>'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0 or 32768 for unsigned int16 input, got 1}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect i32 element type for multiplier for scale32=true, got 'i16'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect i16 element type for multiplier for scale32=false, got 'i32'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1x1xi32> } : () -> tensor<1x1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1x1xi32>'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1x1xi8> } : () -> tensor<1x1xi8>
+ // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for multiplier input, got { 1 }}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for multiplier input, got { 3 }}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for shift input, got { 1 }}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+ %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+ // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for shift input, got { 3 }}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+ return %0 : tensor<13x21x3xi16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index a83fad1035a6d..101cd30a9931a 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -432,8 +432,10 @@ func.func @test_cast_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<
// -----
func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8> {
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, multiplier = array<i32: 1073741824>, shift = array<i8: 30>, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
return %0 : tensor<1x1x1x1x13x21x3xi8>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 4c2cda8d9c027..36ce5860e5c99 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -94,7 +94,9 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
%izp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
%wzp = "tosa.const"() {value = dense<0> : tensor<1xi4>} : () -> tensor<1xi4>
%2 = "tosa.conv2d"(%arg0, %0, %1, %izp, %wzp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi4>) -> tensor<1x1x1x3xi32>
- %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
+ %multiplier = "tosa.const"() {value = dense<[2026291432, 1079222024, 1693132724]> : tensor<3xi32>} : () -> tensor<3xi32>
+ %shift = "tosa.const"() {value = dense<[37, 36, 37]> : tensor<3xi8>} : () -> tensor<3xi8>
+ %3 = tosa.rescale %2, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
return %3 : tensor<1x1x1x3xi8>
}
@@ -711,7 +713,9 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
// -----
// CHECK-LABEL: rescale
func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
- %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 80e55912e131f..6648ecee5371e 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -74,7 +74,7 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
// -----
// CHECK-LABEL: @test_unary_i32
-func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
+func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
// CHECK: tosa.abs %arg0 : (tensor<4xi32>) -> tensor<4xi32>
%0 = tosa.abs %arg0 : (tensor<4xi32>) -> tensor<*xi32>
@@ -93,8 +93,12 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
// CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32>
%5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>
- // CHECK: tosa.rescale %arg0 {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
- %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i8: 14, 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<4xi32>) -> tensor<*xi16>
+ // CHECK-DAG: %[[MULT:.+]] = "tosa.const"() <{value = dense<[42, 43]> : tensor<2xi16>}> : () -> tensor<2xi16>
+ // CHECK-DAG: %[[SHIFT:.+]] = "tosa.const"() <{value = dense<[14, 15]> : tensor<2xi8>}> : () -> tensor<2xi8>
+ // CHECK: tosa.rescale %arg1, %[[MULT]], %[[SHIFT]] {{.+}} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {value = dense<[42, 43]> : tensor<2xi16>} : () -> tensor<2xi16>
+ %shift = "tosa.const"() {value = dense<[14, 15]> : tensor<2xi8>} : () -> tensor<2xi8>
+ %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, double_round = false, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
// CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
%7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 1297596ef193a..a5d1cf863bbf8 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -166,18 +166,22 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
if (!computeMultiplierAndShift(opTensorScale, multiplier, shift, 32))
return failure();
- bool input_unsigned =
+ bool inputUnsigned =
newTosaConv2DOp.getResult().getType().isUnsignedInteger();
- bool output_unsigned = outputType.isUnsignedInteger();
+ bool outputUnsigned = outputType.isUnsignedInteger();
auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
op->getLoc(), outputType, newTosaConv2DOp.getResult(),
- rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
- rewriter.getDenseI32ArrayAttr({multiplier}),
- rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
- rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
- rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
- rewriter.getBoolAttr(output_unsigned));
+ getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}),
+ getConstTensorInt<int8_t>(rewriter, op->getLoc(),
+ {static_cast<int8_t>(shift)}),
+ /* input_zp = */ rewriter.getI32IntegerAttr(0),
+ /* output_zp = */ rewriter.getI32IntegerAttr(outputZp),
+ /* scale32 = */ rewriter.getBoolAttr(true),
+ /* double_round = */ rewriter.getBoolAttr(true),
+ /* per_channel = */ rewriter.getBoolAttr(false),
+ rewriter.getBoolAttr(inputUnsigned),
+ rewriter.getBoolAttr(outputUnsigned));
rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
return success();
More information about the Mlir-commits
mailing list