[Mlir-commits] [mlir] [mlir][tosa] Change Rescale zero points to be inputs (PR #130340)

Tai Ly llvmlistbot at llvm.org
Mon Mar 10 09:45:54 PDT 2025


https://github.com/Tai78641 updated https://github.com/llvm/llvm-project/pull/130340

>From 8af23aeb287f5193c7cb30826147c18df0b82b88 Mon Sep 17 00:00:00 2001
From: Peng Sun <peng.sun at arm.com>
Date: Sat, 25 Jan 2025 18:49:35 +0000
Subject: [PATCH] [mlir][tosa] Change Rescale zero points to be inputs

 * Update RescaleOp to take its zero points as operands instead of attributes.
 * Check the input_zp data type against the input and the output_zp data type
   against the output.

Change-Id: I2cf0106eb9f9ec88e16de5efc93b651053e5fc92
Signed-off-by: Peng Sun <peng.sun at arm.com>
---
 .../Dialect/Tosa/IR/TosaComplianceData.h.inc  | 23 +++--
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 11 ++-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 30 ++++--
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 96 ++++++++++++-------
 .../Tosa/Transforms/TosaProfileCompliance.cpp |  2 +
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir  |  5 +-
 .../TosaToLinalg/tosa-to-linalg.mlir          | 69 ++++++++-----
 mlir/test/Dialect/Tosa/availability.mlir      | 12 ++-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  4 +-
 mlir/test/Dialect/Tosa/invalid.mlir           | 72 ++++++++++----
 mlir/test/Dialect/Tosa/level_check.mlir       |  8 +-
 mlir/test/Dialect/Tosa/ops.mlir               |  8 +-
 .../Tosa/profile_pro_int_unsupported.mlir     | 15 +++
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  8 +-
 mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp | 13 ++-
 15 files changed, 263 insertions(+), 113 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index d3fd4c3d1d3e1..bfdaf14f6b05d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -218,15 +218,15 @@ profileComplianceMap = {
         {fp32T, fp16T}}}}},
     {"tosa.rescale",
      {{{Profile::pro_int},
-       {{i8T, i8T},
-        {i8T, i16T},
-        {i8T, i32T},
-        {i16T, i8T},
-        {i16T, i16T},
-        {i16T, i32T},
-        {i32T, i8T},
-        {i32T, i16T},
-        {i32T, i32T}}}}},
+       {{i8T, i8T, i8T, i8T},
+        {i8T, i8T, i16T, i16T},
+        {i8T, i8T, i32T, i32T},
+        {i16T, i16T, i8T, i8T},
+        {i16T, i16T, i16T, i16T},
+        {i16T, i16T, i32T, i32T},
+        {i32T, i32T, i8T, i8T},
+        {i32T, i32T, i16T, i16T},
+        {i32T, i32T, i32T, i32T}}}}},
     {"tosa.const",
      {{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
       {{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
@@ -386,7 +386,10 @@ extensionComplianceMap = {
         {fp16T, fp8e5m2T},
         {fp32T, fp8e5m2T}}}}},
     {"tosa.rescale",
-     {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
+     {{{Extension::int16},
+       {{i48T, i48T, i8T, i8T},
+        {i48T, i48T, i16T, i16T},
+        {i48T, i48T, i32T, i32T}}}}},
     {"tosa.const",
      {{{Extension::int4}, {{i4T}}},
       {{Extension::int16}, {{i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ecddc9fe9a13d..73d1f0b2b8932 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2344,8 +2344,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     Tosa_Tensor:$input,
     Tosa_1DInt16Or32Tensor:$multiplier,
     Tosa_1DInt8Tensor:$shift,
-    I32Attr:$input_zp,
-    I32Attr:$output_zp,
+    Tosa_ScalarIntOrFloatTensor:$input_zp,
+    Tosa_ScalarIntOrFloatTensor:$output_zp,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel,
@@ -2362,6 +2362,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     Extension<[Tosa_EXT_INT16]>,
   ];
 
+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getInputZeroPoint();
+    FailureOr<int64_t> getOutputZeroPoint();
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
+
   let hasVerifier = 1;
 
   let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..89af54132f820 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
 
 template <typename T>
 static arith::ConstantOp
-createConstFromIntAttribute(Operation *op, const std::string &attrName,
-                            Type requiredAttrType, OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(
-      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
+createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
+                       OpBuilder &rewriter) {
+  auto castedN = static_cast<T>(zp);
   return rewriter.create<arith::ConstantOp>(
       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
 }
@@ -1491,11 +1490,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           // later.
           int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
 
-          auto inputZp = createConstFromIntAttribute<int32_t>(
-              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+          if (failed(maybeIZp)) {
+            (void)rewriter.notifyMatchFailure(
+                op, "input zero point cannot be statically determined");
+            return;
+          }
+
+          auto inputZp = createConstOpFromZpVal<int32_t>(
+              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
               nestedBuilder);
-          auto outputZp = createConstFromIntAttribute<int32_t>(
-              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+          if (failed(maybeOZp)) {
+            (void)rewriter.notifyMatchFailure(
+                op, "output zero point cannot be statically determined");
+            return;
+          };
+
+          auto outputZp = createConstOpFromZpVal<int32_t>(
+              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7a991b3876f69..8869d92b14a70 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -254,6 +254,25 @@ static Type getStorageElementTypeOrSelf(Type type) {
   return elementType;
 }
 
+static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
+                                                  Value valZp, StringRef name) {
+  Type eType = getStorageElementTypeOrSelf(val.getType());
+  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
+
+  bool bothInts =
+      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
+  bool sameBitWidth =
+      (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
+
+  if (!bothInts || !sameBitWidth) {
+    return op->emitOpError()
+           << "expected " << name << " and " << name
+           << "_zp to both be integer of the same bitwidth, but got " << eType
+           << " vs. " << eZpType;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -1708,6 +1727,33 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
   return success();
 }
 
+static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
+                                     const int64_t &zp,
+                                     const std::string &operand) {
+  bool isInputZp = (operand == "Input");
+
+  bool tensorUnsigned =
+      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
+  StringRef tensorName = isInputZp ? "input" : "output";
+
+  Type zpElemType = getElementTypeOrSelf(zpVal);
+
+  if (zp != 0) {
+    if (!zpElemType.isInteger(8) &&
+        !(zpElemType.isInteger(16) && tensorUnsigned)) {
+      return op.emitOpError()
+             << "expect " << tensorName << "_zp of 0, got " << zp;
+    }
+    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
+      return op.emitOpError() << "expect " << tensorName
+                              << "_zp of 0 or 32768 for unsigned int16 "
+                              << tensorName << ", got " << zp;
+    }
+  }
+
+  return success();
+}
+
 #define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
   FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
     return getZeroPoint(*this, get##OPERAND_NAME##Zp());                       \
@@ -1728,6 +1774,8 @@ ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
 ZERO_POINT_HELPER(MatMulOp, A)
 ZERO_POINT_HELPER(MatMulOp, B)
+ZERO_POINT_HELPER(RescaleOp, Input)
+ZERO_POINT_HELPER(RescaleOp, Output)
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2712,41 +2760,21 @@ LogicalResult RescaleOp::verify() {
     return failure();
   }
 
-  auto input_zp = getInputZpAttr().getInt();
-  if (input_zp != 0) {
-    // only int8/uint8 and uint16 input can have non-zero input_zp
-    if (!inputElementType.isInteger(8) &&
-        !(inputElementType.isInteger(16) && getInputUnsigned())) {
-      emitOpError("expect input_zp of 0, got ") << input_zp;
-      return failure();
-    }
-    // input_zp must be either 0 or 32768 for uint16 input
-    if (inputElementType.isInteger(16) && getInputUnsigned() &&
-        input_zp != 32768) {
-      emitOpError(
-          "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
-          << input_zp;
-      return failure();
-    }
-  }
+  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
+          .failed())
+    return failure();
 
-  auto output_zp = getOutputZpAttr().getInt();
-  if (output_zp != 0) {
-    // only int8/uint8 and uint16 output can have non-zero output_zp
-    if (!outputElementType.isInteger(8) &&
-        !(outputElementType.isInteger(16) && getOutputUnsigned())) {
-      emitOpError("expect output_zp of 0, got ") << output_zp;
-      return failure();
-    }
-    // output_zp must be either 0 or 32768 for uint16 output
-    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
-        output_zp != 32768) {
-      emitOpError(
-          "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
-          << output_zp;
-      return failure();
-    }
-  }
+  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
+          .failed())
+    return failure();
+
+  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
+  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+    return failure();
 
   auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
   if (!multiplierType) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 983062ffd7912..08e4bfa5c2548 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
 template <>
 void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
   addValue(op.getInput());
+  addValue(op.getInputZp());
+  addValue(op.getOutputZp());
   addValue(op.getOutput());
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..842b50e804cbe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -35,8 +35,11 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // expected-error at +1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..49d3e86e8fcd0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,9 +1149,11 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
-  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1182,7 +1184,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1195,19 +1199,22 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
 // CHECK-LABEL: @rescale_i8_dyn_batch
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
-  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
 
   return
 }
@@ -1219,15 +1226,19 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
 // CHECK-LABEL: @rescale_dyn
 // CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
 func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+  %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32>} : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<38> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: %[[C1:.+]] = arith.constant 1
   // CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
   // CHECK: %[[C2:.+]] = arith.constant 2
   // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
-  %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
-  %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+
   return
 }
 
@@ -1257,7 +1268,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
@@ -1274,8 +1287,8 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]], [[MULTIPLIERS]], [[SHIFTS]] : tensor<3xi8>, tensor<3xi32>, tensor<3xi8>) outs([[INIT]] : tensor<3xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C243:%.+]] = arith.constant 243
-  // CHECK: [[C252:%.+]] = arith.constant 252
+  // CHECK: [[C243:%.+]] = arith.constant 43
+  // CHECK: [[C252:%.+]] = arith.constant 52
 
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C243]]
@@ -1287,9 +1300,11 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %multiplier = "tosa.const"() {values = dense<[42, 43, 44]> : tensor<3xi16> } : () -> tensor<3xi16>
-  %shift = "tosa.const"() {values = dense<[14, 15, 64]> : tensor<3xi8> } : () -> tensor<3xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
+  %multiplier = "tosa.const"() {values = dense<[42, 43, 44]> : tensor<3xi16>} : () -> tensor<3xi16>
+  %shift = "tosa.const"() {values = dense<[14, 15, 64]> : tensor<3xi8>} : () -> tensor<3xi8>
+  %input_zp = "tosa.const"() {values = dense<43> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<52> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<3xi8>
 
   // CHECK: return [[GENERIC]]
   return %0 : tensor<3xi8>
@@ -1299,23 +1314,29 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
 
 // CHECK-LABEL: @rescaleDoubleRound
 func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32>} : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<33> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<43> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<52> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
-  // CHECK-SAME:  {double_round = true}
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
-  %shift = "tosa.const"() {values = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+  // CHECK-SAME: {double_round = true}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
 // CHECK-LABEL: @rescaleUnnecessaryDoubleRound
 func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32>} : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<43> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<52> : tensor<1xi8>} : () -> tensor<1xi8>
+
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = false}
-  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
-  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index b786264d84106..c1bd54c592c31 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -610,16 +610,18 @@ func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
 }
 
 // -----
-// CHECK-LABEL: rescale
+// CHECK-LABEL: test_rescale
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
-  // CHECK: profiles: [ [pro_int] ]
-  // CHECK: extensions: [ [int16] ]
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+  // CHECK: tosa.rescale profiles: [ [pro_int] ]
+  // CHECK: tosa.rescale extensions: [ [int16] ]
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
 // -----
-// CHECK-LABEL: const
+// CHECK-LABEL: test_const
 func.func @test_const(%arg0 : index) -> tensor<4xi32> {
   // CHECK: profiles: [ [pro_int, pro_fp] ]
   // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 4242f68609634..6ef57a5438c14 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -905,7 +905,9 @@ func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
    %1 = tosa.reshape %0, %cst0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
    %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
-   %2 = tosa.rescale %1, %multiplier, %shift {double_round = true, input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
+   %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
+   %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+   %2 = tosa.rescale %1, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x1x1x1xi32>
    return %2 : tensor<1x1x1x1xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index f536444f6379e..2fdb9b4fc9b0c 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1437,8 +1437,10 @@ func.func @test_argmax_invalid_output_shape(%arg0: tensor<1x2x3xf32>) -> tensor<
 func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xi32> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
   // expected-error at +1 {{'tosa.rescale' op expect input to have integer element type, got 'f32'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
 
@@ -1447,8 +1449,10 @@ func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor
 func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.rescale' op expect output to have integer element type, got 'f32'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1457,8 +1461,10 @@ func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tenso
 func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi48> } : () -> tensor<1xi48>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1xi48>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1467,8 +1473,10 @@ func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> t
 func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
+  %input_zp = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1xi16>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1477,8 +1485,10 @@ func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor
 func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
 
@@ -1487,8 +1497,10 @@ func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tens
 func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<1> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1497,8 +1509,10 @@ func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tens
 func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<1> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0 or 32768 for unsigned int16 input, got 1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1508,8 +1522,10 @@ func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tens
 func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
   // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
 
@@ -1518,8 +1534,10 @@ func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> ten
 func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1528,8 +1546,10 @@ func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> ten
 func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1538,8 +1558,10 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
 func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect i32 element type for multiplier for scale32=true, got 'i16'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1548,8 +1570,10 @@ func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> te
 func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect i16 element type for multiplier for scale32=false, got 'i32'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1558,8 +1582,10 @@ func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> te
 func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1x1xi32> } : () -> tensor<1x1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1x1xi32>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1568,8 +1594,10 @@ func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> t
 func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1x1xi8> } : () -> tensor<1x1xi8>
+  %input_zp = "tosa.const"() {values = dense<1> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1578,8 +1606,10 @@ func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor
 func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for multiplier input, got { 1 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1588,8 +1618,10 @@ func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x
 func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for multiplier input, got { 3 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1598,8 +1630,10 @@ func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13
 func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for shift input, got { 1 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1608,8 +1642,10 @@ func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16
 func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for shift input, got { 3 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {double_round = false, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index 6d8237635d0ec..2899214a3baf4 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -432,10 +432,12 @@ func.func @test_cast_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi32>) -> tensor<
 // -----
 
 func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tensor<1x1x1x1x13x21x3xi8> {
-  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
-  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32>} : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
   return %0 : tensor<1x1x1x1x13x21x3xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 480d8c327ab83..1baf28a7fe986 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -96,7 +96,9 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
   %2 = "tosa.conv2d"(%arg0, %0, %1, %izp, %wzp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi4>) -> tensor<1x1x1x3xi32>
   %multiplier = "tosa.const"() {values = dense<[2026291432, 1079222024, 1693132724]> : tensor<3xi32>} : () -> tensor<3xi32>
   %shift = "tosa.const"() {values = dense<[37, 36, 37]> : tensor<3xi8>} : () -> tensor<3xi8>
-  %3 = tosa.rescale %2, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
+  %rescale_input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %rescale_output_zp = "tosa.const"() <{values = dense<27> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %3 = tosa.rescale %2, %multiplier, %shift, %rescale_input_zp, %rescale_output_zp {scale32 = true, double_round = true, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x3xi8>
   return %3 : tensor<1x1x1x3xi8>
 }
 
@@ -724,7 +726,9 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
    %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
    %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
-   %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+   %input_zp = "tosa.const"() <{values = dense<127> : tensor<1xi8>}> : () -> tensor<1xi8>
+   %output_zp = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> : () -> tensor<1xi8>
+   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
     return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 1d6d33b9a02c7..58f452361cd56 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -24,3 +24,18 @@ func.func @test_cast_i8_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> {
   %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
+
+// -----
+func.func @test_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+  // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+  %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op illegal: requires [pro_int] but not enabled in target}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index deede4b0afadc..f4e797ef49335 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -95,10 +95,14 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
 
   // CHECK-DAG: %[[MULT:.+]] = "tosa.const"() <{values = dense<[42, 43]> : tensor<2xi16>}> : () -> tensor<2xi16>
   // CHECK-DAG: %[[SHIFT:.+]] = "tosa.const"() <{values = dense<[14, 15]> : tensor<2xi8>}> : () -> tensor<2xi8>
-  // CHECK: tosa.rescale %arg1, %[[MULT]], %[[SHIFT]] {{.+}} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<2xi8>
+  // CHECK-DAG: %[[INPUTZP:.+]] = "tosa.const"() <{values = dense<43> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK-DAG: %[[OUTPUTZP:.+]] = "tosa.const"() <{values = dense<52> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // CHECK: tosa.rescale %arg1, %[[MULT]], %[[SHIFT]], %[[INPUTZP]], %[[OUTPUTZP]] {{.+}} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
   %multiplier = "tosa.const"() {values = dense<[42, 43]> : tensor<2xi16>} : () -> tensor<2xi16>
   %shift = "tosa.const"() {values = dense<[14, 15]> : tensor<2xi8>} : () -> tensor<2xi8>
-  %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, double_round = false, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
+  %input_zp = "tosa.const"() {values = dense<43> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<52> : tensor<1xi8>} : () -> tensor<1xi8>
+  %6 = tosa.rescale %arg1, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index a5d1cf863bbf8..9eb56c86c6fdf 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -155,7 +155,15 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
   double inputScale = inputQType.getScale();
   double weightScale = weightQType.getScale();
   double outputScale = outputQType.getScale();
-  int64_t outputZp = outputQType.getZeroPoint();
+  int64_t outputZpVal = outputQType.getZeroPoint();
+
+  auto inputZp =
+      createZeroPointTensor(rewriter, op->getLoc(), newTosaConv2DOpType, 0);
+  auto outputZp = createZeroPointTensor(
+      rewriter, op->getLoc(), tosaConv2DOp.getOutput().getType(), outputZpVal);
+
+  if (!inputZp || !outputZp)
+    return failure();
 
   double opTensorScale = (inputScale * weightScale) / outputScale;
 
@@ -175,8 +183,7 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
       getConstTensorInt<int32_t>(rewriter, op->getLoc(), {multiplier}),
       getConstTensorInt<int8_t>(rewriter, op->getLoc(),
                                 {static_cast<int8_t>(shift)}),
-      /* input_zp = */ rewriter.getI32IntegerAttr(0),
-      /* output_zp = */ rewriter.getI32IntegerAttr(outputZp),
+      inputZp.value(), outputZp.value(),
       /* scale32 = */ rewriter.getBoolAttr(true),
       /* double_round = */ rewriter.getBoolAttr(true),
       /* per_channel = */ rewriter.getBoolAttr(false),



More information about the Mlir-commits mailing list