[Mlir-commits] [mlir] 5685def - [mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs (#129720)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 7 07:59:34 PST 2025


Author: Peng Sun
Date: 2025-03-07T15:59:29Z
New Revision: 5685def507ed7c95bf93572e5cf8c30efbc7d99b

URL: https://github.com/llvm/llvm-project/commit/5685def507ed7c95bf93572e5cf8c30efbc7d99b
DIFF: https://github.com/llvm/llvm-project/commit/5685def507ed7c95bf93572e5cf8c30efbc7d99b.diff

LOG: [mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs (#129720)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
    mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
    mlir/test/Dialect/Tosa/availability.mlir
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/level_check.mlir
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
    mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8782d0bb55878..5340ce52d73fc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2279,9 +2279,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 //===----------------------------------------------------------------------===//
 // Operator: rescale
 //===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>]> {
+def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
   let summary = "Tosa rescale operator";
 
   let description = [{
@@ -2307,10 +2305,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
 
   let arguments = (ins
     Tosa_Tensor:$input,
+    Tosa_1DInt16Or32Tensor:$multiplier,
+    Tosa_1DInt8Tensor:$shift,
     I32Attr:$input_zp,
     I32Attr:$output_zp,
-    DenseI32ArrayAttr:$multiplier,
-    DenseI8ArrayAttr:$shift,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel,
@@ -2327,6 +2325,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     Extension<[Tosa_EXT_INT16]>,
   ];
 
+  let hasVerifier = 1;
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 

diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 7a8357ecfa430..08c0e02139b0c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
 
 def Tosa_Int4 : I<4>;
 def Tosa_Int8 : I<8>;
+def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
 def Tosa_Int64 : I<64>;
 
@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
                           AnySignlessInteger]>;
 
 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
-                   	        Tosa_Int64]>;
+                                Tosa_Int64]>;
+
+// Element type constraint accepting either Tosa_Int16 or Tosa_Int32.
+def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
+                                Tosa_Int32]>;
 
 //===----------------------------------------------------------------------===//
 // Quantized Integer Types.
@@ -163,6 +167,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
 def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
+// 1-D tensors restricted to specific integer element types.
+def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
+def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
+
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;

diff  --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index cb0df7aff35f7..bdd8713037eea 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -31,6 +31,26 @@ namespace tosa {
 bool computeMultiplierAndShift(double scale, int32_t &multiplier,
                                int32_t &shift, int32_t scaleWidth);
 
+/// Builds a `tosa.const` holding the values of `vec` as a rank-1 tensor whose
+/// element type is the builder's integer type of IntType's bit width, and
+/// returns the op's result. `vec` must not be empty (checked by an assert).
+template <typename IntType>
+Value getConstTensorInt(OpBuilder &builder, Location loc,
+                        ArrayRef<IntType> vec) {
+  static_assert(
+      std::is_same_v<IntType, int8_t> || std::is_same_v<IntType, int16_t> ||
+          std::is_same_v<IntType, int32_t>,
+      "getConstTensorInt only supports int8_t, int16_t, and int32_t types.");
+
+  const int64_t numElements = vec.size();
+  assert(numElements > 0 && "Vector must not be empty");
+  IntegerType elemTy = builder.getIntegerType(sizeof(IntType) * 8);
+  // The constant's shape is {numElements}: one element per entry of `vec`.
+  auto tensorTy = RankedTensorType::get({numElements}, elemTy);
+  auto valuesAttr = DenseElementsAttr::get(tensorTy, vec);
+  return builder.create<tosa::ConstOp>(loc, tensorTy, valuesAttr).getResult();
+}
+
 //// Builds ConvOpQuantizationAttr from input and weight.
 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                    Value input, Value weight);

diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a99cf293b9eac..f7dd33c7e8b53 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -138,7 +138,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   // tosa::MulOp
   if (isa<tosa::MulOp>(op)) {
     auto shift_val = cast<tosa::MulOp>(op).getShift();
-    ElementsAttr shift_elem;
+    DenseElementsAttr shift_elem;
     if (!shift_val.getImpl() ||
         !matchPattern(shift_val, m_Constant(&shift_elem))) {
       (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     }
 
     // The shift and multiplier values.
-    SmallVector<int32_t> multiplierValues(op.getMultiplier());
-    SmallVector<int8_t> shiftValues(op.getShift());
+    DenseElementsAttr shiftElems;
+    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant shift input values");
+
+    DenseElementsAttr multiplierElems;
+    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant multiplier input values");
+
+    llvm::SmallVector<int8_t> shiftValues =
+        llvm::to_vector(shiftElems.getValues<int8_t>());
+    // explicit cast is required here
+    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
+        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                        [](IntegerAttr attr) -> int32_t {
+                          return static_cast<int32_t>(attr.getInt());
+                        }));
 
     // If we shift by more than the bitwidth, this just sets to 0.
     for (int i = 0, s = multiplierValues.size(); i < s; i++) {

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 5fdf5f4b5cb6a..854196250bb0c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2226,7 +2226,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
 NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
-NARY_SHAPE_INFER(tosa::RescaleOp)
 NARY_SHAPE_INFER(tosa::ReverseOp)
 NARY_SHAPE_INFER(tosa::RsqrtOp)
 NARY_SHAPE_INFER(tosa::SinOp)
@@ -2676,6 +2675,147 @@ LogicalResult TransposeConv2DOp::verify() {
   return success();
 }
 
+/// Verifies the tosa.rescale invariants checked below, beyond the operand
+/// type constraints declared for the op:
+///  - `input` and `output` must have integer (storage) element types;
+///  - a non-zero `input_zp`/`output_zp` is accepted only for 8-bit data, or
+///    for 16-bit data flagged unsigned, where it must be 32768;
+///  - `multiplier` must be i32 when `scale32` is set and i16 otherwise;
+///  - `multiplier` and `shift` must hold one value per channel: the extent of
+///    the last `input` dimension when `per_channel` is set, 1 otherwise.
+/// Shape checks are skipped whenever the extent involved is unranked or
+/// dynamic.
+LogicalResult RescaleOp::verify() {
+  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
+  if (!inputType)
+    return emitOpError("expect shaped tensor for input, got ")
+           << getInput().getType();
+
+  auto inputElementType =
+      getStorageElementTypeOrSelf(inputType.getElementType());
+  if (!mlir::isa<IntegerType>(inputElementType))
+    return emitOpError("expect input to have integer element type, got ")
+           << inputElementType;
+
+  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
+  if (!outputType)
+    return emitOpError("expect shaped tensor for output, got ")
+           << getOutput().getType();
+
+  auto outputElementType =
+      getStorageElementTypeOrSelf(outputType.getElementType());
+  if (!mlir::isa<IntegerType>(outputElementType))
+    return emitOpError("expect output to have integer element type, got ")
+           << outputElementType;
+
+  auto inputZp = getInputZpAttr().getInt();
+  if (inputZp != 0) {
+    const bool isUnsignedI16Input =
+        inputElementType.isInteger(16) && getInputUnsigned();
+    // only int8/uint8 and uint16 input can have non-zero input_zp
+    if (!inputElementType.isInteger(8) && !isUnsignedI16Input)
+      return emitOpError("expect input_zp of 0, got ") << inputZp;
+    // input_zp must be either 0 or 32768 for uint16 input
+    if (isUnsignedI16Input && inputZp != 32768)
+      return emitOpError(
+                 "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
+             << inputZp;
+  }
+
+  auto outputZp = getOutputZpAttr().getInt();
+  if (outputZp != 0) {
+    const bool isUnsignedI16Output =
+        outputElementType.isInteger(16) && getOutputUnsigned();
+    // only int8/uint8 and uint16 output can have non-zero output_zp
+    if (!outputElementType.isInteger(8) && !isUnsignedI16Output)
+      return emitOpError("expect output_zp of 0, got ") << outputZp;
+    // output_zp must be either 0 or 32768 for uint16 output
+    if (isUnsignedI16Output && outputZp != 32768)
+      return emitOpError(
+                 "expect output_zp of 0 or 32768 for unsigned int16 output, "
+                 "got ")
+             << outputZp;
+  }
+
+  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
+  if (!multiplierType)
+    return emitOpError("expect shaped tensor for multiplier, got ")
+           << getMultiplier().getType();
+
+  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
+  if (!shiftType)
+    return emitOpError("expect shaped tensor for shift, got ")
+           << getShift().getType();
+
+  Type multiplierElementType = multiplierType.getElementType();
+
+  // multiplier element type must be i32 for scale32 = true
+  if (getScale32() && !multiplierElementType.isInteger(32))
+    return emitOpError("expect i32 element type for multiplier for "
+                       "scale32=true, got ")
+           << multiplierElementType;
+
+  // multiplier element type must be i16 for scale32 = false
+  if (!getScale32() && !multiplierElementType.isInteger(16))
+    return emitOpError("expect i16 element type for multiplier for "
+                       "scale32=false, got ")
+           << multiplierElementType;
+
+  if (!inputType.hasRank())
+    return success();
+
+  // multiplier/shift must have shape = {numChannels},
+  // where numChannel is 1 if per_channel = false
+  // otherwise numChannel is dimension in input shape's last axis
+  int64_t numChannels = 1;
+  if (getPerChannel()) {
+    // A rank-0 input has no last axis; reject it here rather than querying
+    // dimension -1 below.
+    if (inputType.getRank() < 1)
+      return emitOpError("requires input to be at least rank 1 when "
+                         "per_channel is true, but got rank ")
+             << inputType.getRank();
+
+    numChannels = inputType.getDimSize(inputType.getRank() - 1);
+    // With a dynamic channel dimension the expected extent is unknown, so
+    // comparing it against static multiplier/shift extents would be wrong.
+    if (numChannels == ShapedType::kDynamic)
+      return success();
+  }
+
+  // The multiplier and shift checks are independent: an unranked multiplier
+  // must not suppress the shift check.
+  if (multiplierType.hasRank()) {
+    // multiplier input has rank 1 by dialect definition
+    const int64_t multiplierSize = multiplierType.getShape()[0];
+    if (multiplierSize != ShapedType::kDynamic &&
+        multiplierSize != numChannels)
+      return emitOpError("expect shape of { ")
+             << numChannels << " } for multiplier input, got { "
+             << multiplierSize << " }";
+  }
+
+  if (shiftType.hasRank()) {
+    // shift input has rank 1 by dialect definition
+    const int64_t shiftSize = shiftType.getShape()[0];
+    if (shiftSize != ShapedType::kDynamic && shiftSize != numChannels)
+      return emitOpError("expect shape of { ")
+             << numChannels << " } for shift input, got { " << shiftSize
+             << " }";
+  }
+
+  return success();
+}
+
+// Shape inference for tosa.rescale: the single inferred result is built from
+// the type of the `input` operand; `multiplier` and `shift` are not consulted.
+LogicalResult RescaleOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RescaleOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  Type operandType = adaptor.getInput().getType();
+  ShapeAdaptor operandShape(operandType);
+  inferredReturnShapes.push_back(ShapedTypeComponents(operandShape));
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index cc737e6c94b48..77687b83e5e3c 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 
 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error @+1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9ba9965315fd3..a3ed8c2805282 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,7 +1149,9 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1178,7 +1180,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1191,17 +1195,19 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
 // CHECK-LABEL: @rescale_i8_dyn_batch
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
 
   return
 }
@@ -1219,7 +1225,9 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
   // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
-  %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 38>, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
+  %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
   return
 }
 
@@ -1247,7 +1255,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
@@ -1277,7 +1287,9 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>) -> tensor<3xi8>
+  %multiplier = "tosa.const"() {values = dense<[42, 43, 44]> : tensor<3xi16> } : () -> tensor<3xi16>
+  %shift = "tosa.const"() {values = dense<[14, 15, 64]> : tensor<3xi8> } : () -> tensor<3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
 
   // CHECK: return [[GENERIC]]
   return %0 : tensor<3xi8>
@@ -1290,7 +1302,9 @@ func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = true}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 33>, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
@@ -1299,7 +1313,9 @@ func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>)
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = false}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 

diff  --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index e90f0af1f9c7d..1952ad79392c7 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -614,7 +614,7 @@ func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   // CHECK: profiles: [ [pro_int] ]
   // CHECK: extensions: [ [int16] ]
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a78479cc0bead..4242f68609634 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -903,7 +903,9 @@ func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
    %0 = "tosa.const"() {values = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %cst0 = "tosa.const_shape"() {values = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
    %1 = tosa.reshape %0, %cst0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1xi32>
+   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+   %2 = tosa.rescale %1, %multiplier, %shift {double_round = true, input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
    return %2 : tensor<1x1x1x1xi32>
 }
 

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index f2956d84c9650..05700ca3765e4 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1431,3 +1431,184 @@ func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<
   %0 = tosa.argmax %arg0 {axis = 0 : i32}: (tensor<1x2x3xf32>) -> tensor<1x2x3xi32>
   return %0 : tensor<1x2x3xi32>
 }
+
+// -----
+
+func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input to have integer element type, got 'f32'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output to have integer element type, got 'f32'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi48> } : () -> tensor<1xi48>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1xi48>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1xi16>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0 or 32768 for unsigned int16 input, got 1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect i32 element type for multiplier for scale32=true, got 'i16'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect i16 element type for multiplier for scale32=false, got 'i32'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1x1xi32> } : () -> tensor<1x1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1x1xi32>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1x1xi8> } : () -> tensor<1x1xi8>
+  // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for multiplier input, got { 1 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for multiplier input, got { 3 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for shift input, got { 1 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for shift input, got { 3 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}

diff  --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 8f64f217514de..bc13b614e3f9d 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -432,8 +432,10 @@ func.func @test_cast_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<
 // -----
 
 func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, multiplier = array<i32: 1073741824>, shift = array<i8: 30>, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
   return %0 : tensor<1x1x1x1x13x21x3xi8>
 }
 

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index f9cd342ccf1d4..e1ac7d5f51d0e 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -94,7 +94,9 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
   %izp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %wzp = "tosa.const"() {values = dense<0> : tensor<1xi4>} : () -> tensor<1xi4>
   %2 = "tosa.conv2d"(%arg0, %0, %1, %izp, %wzp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi4>) -> tensor<1x1x1x3xi32>
-  %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
+  %multiplier = "tosa.const"() {values = dense<[2026291432, 1079222024, 1693132724]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %shift = "tosa.const"() {values = dense<[37, 36, 37]> : tensor<3xi8>} : () -> tensor<3xi8>
+  %3 = tosa.rescale %2, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
   return %3 : tensor<1x1x1x3xi8>
 }
 
@@ -718,7 +720,9 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
 // -----
 // CHECK-LABEL: rescale
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
-    %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+   %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
     return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 6d480b841441a..55c5c3f6bdfb6 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -74,7 +74,7 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
 // -----
 
 // CHECK-LABEL: @test_unary_i32
-func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
+func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
   // CHECK: tosa.abs %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %0 = tosa.abs %arg0 : (tensor<4xi32>) -> tensor<*xi32>
 
@@ -93,8 +93,12 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
   // CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32>
   %5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>
 
-  // CHECK: tosa.rescale %arg0 {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
-  %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i8: 14, 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<4xi32>) -> tensor<*xi16>
+  // CHECK-DAG: %[[MULT:.+]] = "tosa.const"() <{values = dense<[42, 43]> : tensor<2xi16>}> : () -> tensor<2xi16>
+  // CHECK-DAG: %[[SHIFT:.+]] = "tosa.const"() <{values = dense<[14, 15]> : tensor<2xi8>}> : () -> tensor<2xi8>
+  // CHECK: tosa.rescale %arg1, %[[MULT]], %[[SHIFT]] {{.+}} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<[42, 43]> : tensor<2xi16>} : () -> tensor<2xi16>
+  %shift = "tosa.const"() {values = dense<[14, 15]> : tensor<2xi8>} : () -> tensor<2xi8>
+  %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, double_round = false, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
 
   // CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>

diff  --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 1297596ef193a..a5d1cf863bbf8 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -166,18 +166,22 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
   if (!computeMultiplierAndShift(opTensorScale, multiplier, shift, 32))
     return failure();
 
-  bool input_unsigned =
+  bool inputUnsigned =
       newTosaConv2DOp.getResult().getType().isUnsignedInteger();
-  bool output_unsigned = outputType.isUnsignedInteger();
+  bool outputUnsigned = outputType.isUnsignedInteger();
 
   auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
       op->getLoc(), outputType, newTosaConv2DOp.getResult(),
-      rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
-      rewriter.getDenseI32ArrayAttr({multiplier}),
-      rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
-      rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
-      rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
-      rewriter.getBoolAttr(output_unsigned));
+      getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}),
+      getConstTensorInt<int8_t>(rewriter, op->getLoc(),
+                                {static_cast<int8_t>(shift)}),
+      /* input_zp = */ rewriter.getI32IntegerAttr(0),
+      /* output_zp = */ rewriter.getI32IntegerAttr(outputZp),
+      /* scale32 = */ rewriter.getBoolAttr(true),
+      /* double_round = */ rewriter.getBoolAttr(true),
+      /* per_channel = */ rewriter.getBoolAttr(false),
+      rewriter.getBoolAttr(inputUnsigned),
+      rewriter.getBoolAttr(outputUnsigned));
 
   rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
   return success();


        


More information about the Mlir-commits mailing list