[Mlir-commits] [mlir] [mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs (PR #129720)

Peng Sun llvmlistbot at llvm.org
Wed Mar 5 11:46:08 PST 2025


https://github.com/psunn updated https://github.com/llvm/llvm-project/pull/129720

>From 10816560efb08c51ddd1baf71727b3883545d0d2 Mon Sep 17 00:00:00 2001
From: Peng Sun <peng.sun at arm.com>
Date: Thu, 29 Feb 2024 18:31:08 +0000
Subject: [PATCH] [TOSA] Convert RESCALE op multiplier and shift from
 attributes to inputs

This patch updates the TOSA RescaleOp by converting its multiplier and
shift parameters from attributes to explicit input operands, aligning the
op with the TOSA v1.0 specification. The multiplier must be an i32 tensor
when scale32 is true and an i16 tensor otherwise; shift is an i8 tensor.
Additionally, this patch adds RescaleOp-specific implementations of the
inferReturnTypeComponents and verify functions.

Co-authored-by: Tai Ly <tai.ly at arm.com>
Signed-off-by: Peng Sun <peng.sun at arm.com>
Change-Id: I9e21bf757e736dabea5a2e77398e1b8a268b8ee9
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  10 +-
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  20 +-
 .../mlir/Dialect/Tosa/Utils/QuantUtils.h      |  20 ++
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  22 ++-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 133 ++++++++++++-
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir  |   4 +-
 .../TosaToLinalg/tosa-to-linalg.mlir          |  34 +++-
 mlir/test/Dialect/Tosa/availability.mlir      |   2 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |   4 +-
 mlir/test/Dialect/Tosa/invalid.mlir           | 181 ++++++++++++++++++
 mlir/test/Dialect/Tosa/level_check.mlir       |   4 +-
 mlir/test/Dialect/Tosa/ops.mlir               |   8 +-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  10 +-
 mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp |  20 +-
 14 files changed, 436 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 85bd3fb1bb1cc..310522abdcf34 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2262,9 +2262,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
 //===----------------------------------------------------------------------===//
 // Operator: rescale
 //===----------------------------------------------------------------------===//
-def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
-      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>]> {
+def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
   let summary = "Tosa rescale operator";
 
   let description = [{
@@ -2290,10 +2288,10 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
 
   let arguments = (ins
     Tosa_Tensor:$input,
+    Tosa_1DInt16Or32Tensor:$multiplier,
+    Tosa_1DInt8Tensor:$shift,
     I32Attr:$input_zp,
     I32Attr:$output_zp,
-    DenseI32ArrayAttr:$multiplier,
-    DenseI8ArrayAttr:$shift,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel,
@@ -2310,6 +2308,8 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     Extension<[Tosa_EXT_INT16]>,
   ];
 
+  let hasVerifier = 1;
+
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
 }
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index cf6ddc66f4ada..913be435e0332 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -43,6 +43,7 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
 
 def Tosa_Int4 : I<4>;
 def Tosa_Int8 : I<8>;
+def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
 def Tosa_Int64 : I<64>;
 
@@ -54,7 +55,10 @@ def Tosa_Int : AnyTypeOf<[AnyUnsignedInteger,
                           AnySignlessInteger]>;
 
 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
-                   	        Tosa_Int64]>;
+                                Tosa_Int64]>;
+
+def Tosa_Int16Or32 : AnyTypeOf<[Tosa_Int16,
+                                Tosa_Int32]>;
 
 //===----------------------------------------------------------------------===//
 // Quantized Integer Types.
@@ -74,6 +78,16 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
                                    Tosa_QuantizedType<"int16", [16, 0], 1>,
                                    Tosa_QuantizedType<"int32", [32, 0], 1>]>;
 
+//===----------------------------------------------------------------------===//
+// Floating-point types.
+//===----------------------------------------------------------------------===//
+def Tosa_Float : AnyTypeOf<[F32,
+                            F16,
+                            BF16]>;
+
+def Tosa_F8 : AnyTypeOf<[F8E4M3FN,
+                         F8E5M2]>;
+
 //===----------------------------------------------------------------------===//
 // Multi-category types.
 //===----------------------------------------------------------------------===//
@@ -162,6 +176,10 @@ def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNu
 def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
 def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
 
+// 1D tensor of specific types
+def Tosa_1DInt8Tensor : 1DTensorOf<[Tosa_Int8]>;
+def Tosa_1DInt16Or32Tensor : 1DTensorOf<[Tosa_Int16Or32]>;
+
 // Ranked tensors up to given rank.
 def Tosa_Tensor1Dto4D : AnyTypeOf<[
   Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 10dc5dd36cfa9..49caac997e0d3 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -31,6 +31,26 @@ namespace tosa {
 void computeMultiplierAndShift(double scale, int32_t &multiplier,
                                int32_t &shift, int32_t scaleWidth);
 
+// Materialize `vec` as a rank-1 tosa.const tensor of signless integers whose
+// bit width matches IntType. `vec` must be non-empty.
+template <typename IntType>
+Value getConstTensorInt(OpBuilder &builder, Location loc,
+                        ArrayRef<IntType> vec) {
+  static_assert(std::is_same_v<IntType, int8_t> ||
+                    std::is_same_v<IntType, int16_t> ||
+                    std::is_same_v<IntType, int32_t>,
+      "getConstTensorInt only supports int8_t, int16_t, and int32_t types.");
+  int64_t count = vec.size();
+  assert(count > 0 && "Vector must not be empty");
+
+  // Element width (in bits) is derived from the C++ integer type.
+  IntegerType elementType = builder.getIntegerType(sizeof(IntType) * 8);
+  auto tensorType = RankedTensorType::get({count}, elementType);
+  auto valueAttr = DenseElementsAttr::get(tensorType, vec);
+  auto constOp = builder.create<tosa::ConstOp>(loc, tensorType, valueAttr);
+  return constOp.getResult();
+}
+
 //// Builds ConvOpQuantizationAttr from input and weight.
 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                    Value input, Value weight);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..b7319297e0c25 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -138,7 +138,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   // tosa::MulOp
   if (isa<tosa::MulOp>(op)) {
     auto shift_val = cast<tosa::MulOp>(op).getShift();
-    ElementsAttr shift_elem;
+    DenseElementsAttr shift_elem;
     if (!shift_val.getImpl() ||
         !matchPattern(shift_val, m_Constant(&shift_elem))) {
       (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
@@ -1389,8 +1389,24 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     }
 
     // The shift and multiplier values.
-    SmallVector<int32_t> multiplierValues(op.getMultiplier());
-    SmallVector<int8_t> shiftValues(op.getShift());
+    DenseElementsAttr shiftElems;
+    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant shift input values");
+
+    DenseElementsAttr multiplierElems;
+    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale requires constant multiplier input values");
+
+    llvm::SmallVector<int8_t> shiftValues =
+        llvm::to_vector(shiftElems.getValues<int8_t>());
+    // explicit cast is required here
+    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
+        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                        [](IntegerAttr attr) -> int32_t {
+                          return static_cast<int32_t>(attr.getInt());
+                        }));
 
     // If we shift by more than the bitwidth, this just sets to 0.
     for (int i = 0, s = multiplierValues.size(); i < s; i++) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1050f3f30fe98..452c0dc4e4cce 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2024,7 +2024,6 @@ NARY_SHAPE_INFER(tosa::MinimumOp)
 NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
-NARY_SHAPE_INFER(tosa::RescaleOp)
 NARY_SHAPE_INFER(tosa::ReverseOp)
 NARY_SHAPE_INFER(tosa::RsqrtOp)
 NARY_SHAPE_INFER(tosa::SinOp)
@@ -2469,6 +2468,138 @@ LogicalResult TransposeConv2DOp::verify() {
   return success();
 }
 
+// Verify a RESCALE zero point: a non-zero value is only permitted for 8-bit
+// operands and unsigned 16-bit operands, and the latter only accept 32768.
+// `op` is used for diagnostics; `operand` is "input" or "output" and is
+// spliced into the error message.
+static LogicalResult verifyRescaleZeroPoint(RescaleOp op, int64_t zp,
+                                            Type elementType, bool isUnsigned,
+                                            StringRef operand) {
+  if (zp == 0)
+    return success();
+
+  bool isUnsignedInt16 = elementType.isInteger(16) && isUnsigned;
+  if (!elementType.isInteger(8) && !isUnsignedInt16)
+    return op.emitOpError("expect ") << operand << "_zp of 0, got " << zp;
+
+  if (isUnsignedInt16 && zp != 32768)
+    return op.emitOpError("expect ")
+           << operand << "_zp of 0 or 32768 for unsigned int16 " << operand
+           << ", got " << zp;
+
+  return success();
+}
+
+// Verify RESCALE operand/result types and attributes:
+//  - input and output must be shaped with integer (storage) element types;
+//  - input_zp/output_zp must be valid for their operand's element type;
+//  - the multiplier element type must be i32 if scale32 is set, else i16;
+//  - multiplier and shift must hold one value per channel.
+LogicalResult RescaleOp::verify() {
+  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
+  if (!inputType)
+    return emitOpError("expect shaped tensor for input, got ")
+           << getInput().getType();
+
+  // Quantized element types are checked through their storage type.
+  auto inputElementType =
+      getStorageElementTypeOrSelf(inputType.getElementType());
+  if (!mlir::isa<IntegerType>(inputElementType))
+    return emitOpError("expect input to have integer element type, got ")
+           << inputElementType;
+
+  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
+  if (!outputType)
+    return emitOpError("expect shaped tensor for output, got ")
+           << getOutput().getType();
+
+  auto outputElementType =
+      getStorageElementTypeOrSelf(outputType.getElementType());
+  if (!mlir::isa<IntegerType>(outputElementType))
+    return emitOpError("expect output to have integer element type, got ")
+           << outputElementType;
+
+  if (failed(verifyRescaleZeroPoint(*this, getInputZpAttr().getInt(),
+                                    inputElementType, getInputUnsigned(),
+                                    "input")))
+    return failure();
+  if (failed(verifyRescaleZeroPoint(*this, getOutputZpAttr().getInt(),
+                                    outputElementType, getOutputUnsigned(),
+                                    "output")))
+    return failure();
+
+  // multiplier and shift are declared as 1-D tensors in ODS; these casts
+  // only guard against a non-shaped type reaching the verifier.
+  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
+  if (!multiplierType)
+    return emitOpError("expect shaped tensor for multiplier, got ")
+           << getMultiplier().getType();
+
+  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
+  if (!shiftType)
+    return emitOpError("expect shaped tensor for shift, got ")
+           << getShift().getType();
+
+  // multiplier element type must be i32 for scale32 = true
+  Type multiplierElementType = multiplierType.getElementType();
+  if (getScale32() && !multiplierElementType.isInteger(32))
+    return emitOpError(
+        "expect i32 element type for multiplier for scale32=true, got ")
+           << multiplierElementType;
+
+  // multiplier element type must be i16 for scale32 = false
+  if (!getScale32() && !multiplierElementType.isInteger(16))
+    return emitOpError(
+        "expect i16 element type for multiplier for scale32=false, got ")
+           << multiplierElementType;
+
+  // multiplier/shift must have shape = {numChannels}, where numChannels is 1
+  // if per_channel = false, otherwise the size of the input's last axis.
+  // ShapedType::kDynamic marks a channel count unknown at compile time.
+  int64_t numChannels = 1;
+  if (getPerChannel()) {
+    if (!inputType.hasRank()) {
+      // An unranked input has no known last axis; its shape must not be read.
+      numChannels = ShapedType::kDynamic;
+    } else if (inputType.getRank() == 0) {
+      // A rank-0 input has no last axis to carry the channels.
+      return emitOpError("expect input of rank >= 1 for per_channel=true, "
+                         "got rank 0");
+    } else {
+      numChannels = inputType.getShape().back();
+    }
+  }
+
+  // Compare the single dimension of a rank-1 multiplier/shift operand with
+  // numChannels. A dynamic extent on either side cannot be checked
+  // statically, so it is accepted here.
+  auto verifyChannels = [&](ShapedType type, StringRef name) -> LogicalResult {
+    // multiplier/shift inputs have rank 1 by dialect definition
+    int64_t size = type.getDimSize(0);
+    if (ShapedType::isDynamic(size) || ShapedType::isDynamic(numChannels) ||
+        size == numChannels)
+      return success();
+    return emitOpError("expect shape of { ")
+           << numChannels << " } for " << name << " input, got { " << size
+           << " }";
+  };
+
+  if (failed(verifyChannels(multiplierType, "multiplier")))
+    return failure();
+  if (failed(verifyChannels(shiftType, "shift")))
+    return failure();
+  return success();
+}
+
+LogicalResult RescaleOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    RescaleOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  // The single result takes its shape directly from the input operand.
+  inferredReturnShapes.emplace_back(ShapeAdaptor(adaptor.getInput().getType()));
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index afc1d5c609181..ddaaf5fcf7120 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -33,8 +33,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 
 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error at +1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..16a49376e58f5 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,7 +1149,9 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1178,7 +1180,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1191,17 +1195,19 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
 // CHECK-LABEL: @rescale_i8_dyn_batch
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
 
   return
 }
@@ -1219,7 +1225,9 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
   // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
-  %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 38>, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
+  %multiplier = "tosa.const"() {value = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
   return
 }
 
@@ -1247,7 +1255,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
@@ -1277,7 +1287,9 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>) -> tensor<3xi8>
+  %multiplier = "tosa.const"() {value = dense<[42, 43, 44]> : tensor<3xi16> } : () -> tensor<3xi16>
+  %shift = "tosa.const"() {value = dense<[14, 15, 64]> : tensor<3xi8> } : () -> tensor<3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, double_round = false, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
 
   // CHECK: return [[GENERIC]]
   return %0 : tensor<3xi8>
@@ -1290,7 +1302,9 @@ func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = true}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 33>, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
@@ -1299,7 +1313,9 @@ func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>)
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = false}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 98290c7b9eedd..12126d90f32b5 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -611,7 +611,7 @@ func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   // CHECK: profiles: [ [pro_int] ]
   // CHECK: extensions: [ [int16] ]
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index a0184e2d82704..5c9e58ae1bedc 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -897,7 +897,9 @@ func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
    %0 = "tosa.const"() {value = dense<127> : tensor<i8>} : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %cst0 = "tosa.const_shape"() {value = dense<[1, 1, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
    %1 = tosa.reshape %0, %cst0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %2 = tosa.rescale %1 {double_round = true, input_zp = -128 : i32, multiplier = array<i32: 1073741824>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x1x1x1xi32>
+   %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+   %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+   %2 = tosa.rescale %1, %multiplier, %shift {double_round = true, input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
    return %2 : tensor<1x1x1x1xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index dc556f7486774..aeb1071bb695e 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1348,3 +1348,184 @@ func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (te
   %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
   return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
 }
+
+// -----
+
+func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input to have integer element type, got 'f32'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output to have integer element type, got 'f32'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi48> } : () -> tensor<1xi48>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1xi48>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1xi16>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0 or 32768 for unsigned int16 input, got 1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect i32 element type for multiplier for scale32=true, got 'i16'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect i16 element type for multiplier for scale32=false, got 'i32'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1x1xi32> } : () -> tensor<1x1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1x1xi32>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1x1xi8> } : () -> tensor<1x1xi8>
+  // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for multiplier input, got { 1 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for multiplier input, got { 3 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for shift input, got { 1 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
+
+// -----
+
+func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
+  %multiplier = "tosa.const"() {value = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+  // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for shift input, got { 3 }}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  return %0 : tensor<13x21x3xi16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index c136e8aac9606..7610b9e2fde99 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -431,8 +431,10 @@ func.func @test_cast_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<
 // -----
 
 func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8> {
+  %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, multiplier = array<i32: 1073741824>, shift = array<i8: 30>, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
   return %0 : tensor<1x1x1x1x13x21x3xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 45a87b97125f7..4b523fd3a7b0a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -82,7 +82,9 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
   %izp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   %wzp = "tosa.const"() {value = dense<0> : tensor<1xi4>} : () -> tensor<1xi4>
   %2 = "tosa.conv2d"(%arg0, %0, %1, %izp, %wzp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi4>) -> tensor<1x1x1x3xi32>
-  %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
+  %multiplier = "tosa.const"() {value = dense<[2026291432, 1079222024, 1693132724]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %shift = "tosa.const"() {value = dense<[37, 36, 37]> : tensor<3xi8>} : () -> tensor<3xi8>
+  %3 = tosa.rescale %2, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
   return %3 : tensor<1x1x1x3xi8>
 }
 
@@ -707,7 +709,9 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
 // -----
 // CHECK-LABEL: rescale
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
-    %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+   %multiplier = "tosa.const"() {value = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+   %shift = "tosa.const"() {value = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+   %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
     return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 77d77ba957621..4cdf09e1962fc 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -74,7 +74,7 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () {
 // -----
 
 // CHECK-LABEL: @test_unary_i32
-func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
+func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
   // CHECK: tosa.abs %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %0 = tosa.abs %arg0 : (tensor<4xi32>) -> tensor<*xi32>
 
@@ -93,8 +93,12 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
   // CHECK: tosa.reverse %arg0 {axis = 0 : i32} : (tensor<4xi32>) -> tensor<4xi32>
   %5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>
 
-  // CHECK: tosa.rescale %arg0 {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
-  %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i8: 14, 15>, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<4xi32>) -> tensor<*xi16>
+  // CHECK-DAG: %[[MULT:.+]] = "tosa.const"() <{value = dense<[42, 43]> : tensor<2xi16>}> : () -> tensor<2xi16>
+  // CHECK-DAG: %[[SHIFT:.+]] = "tosa.const"() <{value = dense<[14, 15]> : tensor<2xi8>}> : () -> tensor<2xi8>
+  // CHECK: tosa.rescale %arg1, %[[MULT]], %[[SHIFT]] {{.+}} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {value = dense<[42, 43]> : tensor<2xi16>} : () -> tensor<2xi16>
+  %shift = "tosa.const"() {value = dense<[14, 15]> : tensor<2xi8>} : () -> tensor<2xi8>
+  %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, double_round = false, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
 
   // CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index a97711b51ff7e..fe1422c71539a 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -165,18 +165,22 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
   // Obtain the quantized scale = multiplier and shift.
   computeMultiplierAndShift(opTensorScale, multiplier, shift, 32);
 
-  bool input_unsigned =
+  bool inputUnsigned =
       newTosaConv2DOp.getResult().getType().isUnsignedInteger();
-  bool output_unsigned = outputType.isUnsignedInteger();
+  bool outputUnsigned = outputType.isUnsignedInteger();
 
   auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
       op->getLoc(), outputType, newTosaConv2DOp.getResult(),
-      rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
-      rewriter.getDenseI32ArrayAttr({multiplier}),
-      rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
-      rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
-      rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned),
-      rewriter.getBoolAttr(output_unsigned));
+      getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}),
+      getConstTensorInt<int8_t>(rewriter, op->getLoc(),
+                                {static_cast<int8_t>(shift)}),
+      /* input_zp = */ rewriter.getI32IntegerAttr(0),
+      /* output_zp = */ rewriter.getI32IntegerAttr(outputZp),
+      /* scale32 = */ rewriter.getBoolAttr(true),
+      /* double_round = */ rewriter.getBoolAttr(true),
+      /* per_channel = */ rewriter.getBoolAttr(false),
+      rewriter.getBoolAttr(inputUnsigned),
+      rewriter.getBoolAttr(outputUnsigned));
 
   rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
   return success();



More information about the Mlir-commits mailing list