[Mlir-commits] [mlir] [mlir][tosa] Add support for EXT-DOUBLEROUND and EXT-INEXACTROUND (PR #130337)

TatWai Chong llvmlistbot at llvm.org
Fri Mar 7 12:09:05 PST 2025


https://github.com/tatwaichong created https://github.com/llvm/llvm-project/pull/130337

Adds the EXT-DOUBLEROUND and EXT-INEXACTROUND extensions to the TOSA
dialect. It also replaces the boolean "double_round" attribute on
tosa.rescale and tosa.apply_scale with a string-typed "rounding_mode"
attribute that accepts one of the following values:
"DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".

The validation pass has been updated to ensure "DOUBLE_ROUND"
and "INEXACT_ROUND" are only valid when their extensions are
available.

Finally, the lowerings to arith and linalg have been updated to read the
new "rounding_mode" attribute. Lowering of "INEXACT_ROUND" is not
currently supported; the conversion patterns fail to match for that mode.

>From 0acc7fd3815e0a00a2290c64ecd8c394b7565b11 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 15 Nov 2024 11:49:29 +0000
Subject: [PATCH] [mlir][tosa] Add support for EXT-DOUBLEROUND and
 EXT-INEXACTROUND

This commit adds the EXT-DOUBLEROUND and EXT-INEXACTROUND extensions to
the TOSA dialect. It also replaces the boolean "double_round" attribute
on tosa.rescale and tosa.apply_scale with a string-typed "rounding_mode"
attribute that accepts one of the following values:
"DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".

The validation pass has been updated to ensure "DOUBLE_ROUND" and
"INEXACT_ROUND" are only valid when their extensions are available.

Finally, the lowerings to arith and linalg have been updated to read the
new "rounding_mode" attribute. Lowering of "INEXACT_ROUND" is not
currently supported; the conversion patterns fail to match for that mode.

Co-authored-by: TatWai Chong <tatwai.chong at arm.com>
---
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |  7 +++-
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  2 +-
 .../Dialect/Tosa/IR/TosaProfileCompliance.h   |  2 +
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |  7 ++++
 .../mlir/Dialect/Tosa/IR/TosaUtilOps.td       |  2 +-
 .../Conversion/TosaToArith/TosaToArith.cpp    | 14 ++++++-
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 15 ++++---
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  2 +-
 .../Tosa/Transforms/TosaValidation.cpp        | 31 +++++++++++++-
 .../TosaToArith/tosa-to-arith-invalid.mlir    |  8 ++++
 .../Conversion/TosaToArith/tosa-to-arith.mlir |  8 ++--
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir  |  2 +-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  2 +-
 .../TosaToLinalg/tosa-to-linalg.mlir          | 40 ++++++++++++-------
 mlir/test/Dialect/Tosa/availability.mlir      |  2 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  2 +-
 mlir/test/Dialect/Tosa/invalid.mlir           | 39 +++++++++---------
 mlir/test/Dialect/Tosa/invalid_extension.mlir | 30 ++++++++++++++
 mlir/test/Dialect/Tosa/level_check.mlir       |  2 +-
 mlir/test/Dialect/Tosa/ops.mlir               |  4 +-
 .../Dialect/Tosa/profile_all_unsupported.mlir |  2 +-
 .../Tosa/profile_pro_fp_unsupported.mlir      |  2 +-
 .../Tosa/profile_pro_int_unsupported.mlir     | 13 +++++-
 mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir |  2 +-
 mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp |  2 +-
 25 files changed, 180 insertions(+), 62 deletions(-)
 create mode 100644 mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f2328003e49c5..db725dbd5e1bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -226,6 +226,8 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // FFT          : Fast Fourier Transform operations.
 // VARIABLE     : Stateful variable operations.
 // CONTROLFLOW  : Control Flow operations.
+// DOUBLEROUND  : Adds double rounding support to the RESCALE operator.
+// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -241,11 +243,14 @@ def Tosa_EXT_FP8E5M2      : I32EnumAttrCase<"fp8e5m2", 5>;
 def Tosa_EXT_FFT          : I32EnumAttrCase<"fft", 6>;
 def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
 def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_DOUBLEROUND  : I32EnumAttrCase<"doubleround", 9>;
+def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
-      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
+      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
+      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5340ce52d73fc..3f87e299cbbdd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2310,7 +2310,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     I32Attr:$input_zp,
     I32Attr:$output_zp,
     BoolAttr:$scale32,
-    BoolAttr:$double_round,
+    Tosa_RoundingTypeAttr:$rounding_mode,
     BoolAttr:$per_channel,
     BoolAttr: $input_unsigned,
     BoolAttr: $output_unsigned
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 969d06afc70d6..88f454f63e6f9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -136,6 +136,8 @@ class TosaProfileCompliance {
     switch (ext) {
     case Extension::int16:
     case Extension::int4:
+    case Extension::doubleround:
+    case Extension::inexactround:
       return {Profile::pro_int};
     case Extension::bf16:
     case Extension::fp8e4m3:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 08c0e02139b0c..0038d8c386ca7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -247,6 +247,13 @@ def Tosa_NanPropagationAttr : StringBasedAttr<
           "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
     "Supported NaN propagation strategies">;
 
+// Rounding mode for tosa.rescale
+def Tosa_RoundingTypeAttr : StringBasedAttr<
+    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\"  || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
+    "Supported rounding modes">;
+
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
 // Tensor to buffer types.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 8756cb9e5de3a..8a27e5ba39331 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
     Tosa_IntLike:$value,
     Tosa_IntLike:$multiplier,
     Tosa_Int8Like:$shift,
-    BoolAttr:$double_round
+    Tosa_RoundingTypeAttr:$rounding_mode
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 5c84b4063da2e..9dea12355a519 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -65,6 +65,11 @@ class ApplyScaleGenericOpConverter
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
+    StringRef roundingMode = op.getRoundingMode();
+    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+      return failure();
+    }
+
     Location loc = op.getLoc();
     Value value = op.getValue();
     Value multiplier32 = op.getMultiplier();
@@ -96,7 +101,7 @@ class ApplyScaleGenericOpConverter
     multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
 
     // Apply double rounding if necessary.
-    if (op.getDoubleRound()) {
+    if (op.getRoundingMode() == "DOUBLE_ROUND") {
       int64_t roundInt = 1 << 30;
       Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
       Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -125,6 +130,11 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
+    StringRef roundingMode = op.getRoundingMode();
+    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+      return failure();
+    }
+
     Location loc = op.getLoc();
 
     Type resultTy = op.getType();
@@ -170,7 +180,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
         rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
 
     // Conditionally perform our double round.
-    if (op.getDoubleRound()) {
+    if (op.getRoundingMode() == "DOUBLE_ROUND") {
       Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
       Value valuePositive = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..b59e55302a60c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -170,7 +170,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
         auto result = rewriter.create<tosa::ApplyScaleOp>(
             loc, rewriter.getI32Type(), a, b, shiftConst,
-            rewriter.getBoolAttr(false));
+            rewriter.getStringAttr("SINGLE_ROUND"));
 
         if (elementTy.isInteger(32))
           return result;
@@ -1374,7 +1374,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     unsigned rank = inputTy.getRank();
 
     // This is an illegal configuration. terminate and log an error
-    if (op.getDoubleRound() && !op.getScale32())
+    if (op.getRoundingMode() == "INEXACT_ROUND")
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not currently supported");
+    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
@@ -1418,9 +1421,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 
     // Double round only occurs if shift is greater than 31, check that this
     // is ever true.
+
     bool doubleRound =
-        op.getDoubleRound() &&
+        op.getRoundingMode() == "DOUBLE_ROUND" &&
         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    StringAttr roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND") :
+        rewriter.getStringAttr("SINGLE_ROUND");
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1515,8 +1521,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
 
           value = nestedBuilder.create<tosa::ApplyScaleOp>(
-              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
-              nestedBuilder.getBoolAttr(doubleRound));
+              loc, nestedBuilder.getI32Type(), value, multiplier, shift, roundingMode);
 
           // Move to the new zero-point.
           value =
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..2dd3d2fb3325d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1005,7 +1005,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
                 rewriter
                     .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
                                                 poolVal, multiplier, shift,
-                                                rewriter.getBoolAttr(false))
+                                                rewriter.getStringAttr("SINGLE_ROUND"))
                     .getResult();
 
             // If we have quantization information we need to apply output
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b8604ef40cc93..70c4cd0a526cd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   }
 
   LogicalResult applyLevelCheck(Operation *op);
+  LogicalResult applyAttributeCheck(Operation *op);
 
   // check variable read/write data types against variable declarations
   LogicalResult applyVariableCheck(Operation *op);
@@ -386,6 +387,23 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  bool attributeCheckRescale(Operation *op) {
+    if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
+      if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+          !targetEnv.allows(Extension::doubleround)) {
+        op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND "
+                          << "requires extension [doubleround]";
+        return false;
+      } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
+          !targetEnv.allows(Extension::inexactround)) {
+        op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND "
+                          << "requires extension [inexactround]";
+        return false;
+      }
+    }
+    return true;
+  }
+
   // configure profile and level values from pass options profileName and
   // levelName
   void configLevelAndProfile() {
@@ -415,7 +433,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         } else {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
-                       << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
+                       << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
+                       << "doubleround and inexactround\n";
           return signalPassFailure();
         }
       }
@@ -642,6 +661,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   return success();
 }
 
+LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
+  if (!attributeCheckRescale(op))
+    return failure();
+  return success();
+}
+
 inline bool CompatibleTypes(const mlir::Type &type,
                             const mlir::Type &declaredType) {
   // for now, simply use type equality comparison
@@ -936,6 +961,10 @@ void TosaValidation::runOnOperation() {
     if (failed(applyLevelCheck(op)))
       signalPassFailure();
 
+    // check additional attribute restrictions
+    if (failed(applyAttributeCheck(op)))
+      signalPassFailure();
+
     // do variable type checks
     if (failed(applyVariableCheck(op)))
       signalPassFailure();
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
new file mode 100644
index 0000000000000..4b324955439aa
--- /dev/null
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics
+
+// CHECK-LABEL: @apply_scale_unsupported_inexact_round
+func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // expected-error at +1 {{failed to legalize operation 'tosa.apply_scale'}}
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
+  return %res : i32
+}
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 14f811727c456..db68ca40879f4 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
   // CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
   // CHECK: return %[[RESULT]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
   return %res : i32
 }
 
@@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
 // SCALE: tosa.apply_scale
 func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
   // CHECK-NOT: "tosa.apply_scale"
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
   return %res : vector<4xi32>
 }
 
@@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
   // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
   // CHECK: return %[[TRUNC]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
   return %res : i32
 }
 
@@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
   // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
   // CHECK: return %[[TRUNC]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
   return %res : i32
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..54c6ed994e947 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -36,7 +36,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error at +1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..a89da9a2b9fed 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -411,7 +411,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
   // CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
   // CHECK: %[[C30:.+]] = arith.constant 30
   // CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
-  // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {double_round = false}
+  // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
 
   // Perform the normalization.
   // CHECK: %[[CMIN:.+]] = arith.constant -128
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..596069ad7f53d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1141,7 +1141,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1151,7 +1151,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1172,7 +1172,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1182,7 +1182,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1201,13 +1201,13 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
 
   return
 }
@@ -1227,7 +1227,7 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
   %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
   return
 }
 
@@ -1247,7 +1247,7 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1257,7 +1257,7 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
@@ -1279,7 +1279,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
 
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C243]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C252]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1289,7 +1289,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
   // CHECK-DAG: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<[42, 43, 44]> : tensor<3xi16> } : () -> tensor<3xi16>
   %shift = "tosa.const"() {values = dense<[14, 15, 64]> : tensor<3xi8> } : () -> tensor<3xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
 
   // CHECK: return [[GENERIC]]
   return %0 : tensor<3xi8>
@@ -1301,10 +1301,10 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
 func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
-  // CHECK-SAME:  {double_round = true}
+  // CHECK-SAME:  {rounding_mode = "DOUBLE_ROUND"}
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
@@ -1312,10 +1312,20 @@ func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
 func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
   // CHECK: tosa.apply_scale
-  // CHECK-SAME:  {double_round = false}
+  // CHECK-SAME:  {rounding_mode = "SINGLE_ROUND"}
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, rounding_mode = "INEXACT_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 1952ad79392c7..6ec0d548a2dee 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -614,7 +614,7 @@ func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   // CHECK: profiles: [ [pro_int] ]
   // CHECK: extensions: [ [int16] ]
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 4242f68609634..bf7fbde01447e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -905,7 +905,7 @@ func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
    %1 = tosa.reshape %0, %cst0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
    %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
-   %2 = tosa.rescale %1, %multiplier, %shift {double_round = true, input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
+   %2 = tosa.rescale %1, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
    return %2 : tensor<1x1x1x1xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 05700ca3765e4..51ee21fea23bd 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
 
 func.func @test_const() -> tensor<1xf32> {
   // expected-error@+1{{'tosa.const' op expected same attr/result element types}}
@@ -989,6 +989,7 @@ func.func @test_mismatch_in_out_shape_clz(%arg0: tensor<13x21x3xi32>) -> tensor<
 }
 
 // -----
+
 // CHECK-LABEL: test_mismatch_in_out_data_type_cos
 func.func @test_mismatch_in_out_data_type_cos(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf16> {
   // expected-error@+1 {{'tosa.cos' op requires the same element type for all operands and results}}
@@ -1438,7 +1439,7 @@ func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect input to have integer element type, got 'f32'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
 
@@ -1448,7 +1449,7 @@ func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tenso
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect output to have integer element type, got 'f32'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1458,7 +1459,7 @@ func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> t
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi48> } : () -> tensor<1xi48>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1xi48>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1468,7 +1469,7 @@ func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
   // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1xi16>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
   return %0 : tensor<13x21x3xf32>
 }
 
@@ -1478,7 +1479,7 @@ func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tens
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
 
@@ -1488,7 +1489,7 @@ func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tens
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1498,7 +1499,7 @@ func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tens
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect input_zp of 0 or 32768 for unsigned int16 input, got 1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1509,7 +1510,7 @@ func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> ten
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
   return %0 : tensor<13x21x3xi32>
 }
 
@@ -1519,7 +1520,7 @@ func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> ten
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1529,7 +1530,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1539,7 +1540,7 @@ func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> te
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect i32 element type for multiplier for scale32=true, got 'i16'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1549,7 +1550,7 @@ func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> te
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect i16 element type for multiplier for scale32=false, got 'i32'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1559,7 +1560,7 @@ func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> t
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1x1xi32> } : () -> tensor<1x1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1x1xi32>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1569,7 +1570,7 @@ func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1x1xi8> } : () -> tensor<1x1xi8>
   // expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1579,7 +1580,7 @@ func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for multiplier input, got { 1 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1589,7 +1590,7 @@ func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for multiplier input, got { 3 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1599,7 +1600,7 @@ func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for shift input, got { 1 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
 
@@ -1609,6 +1610,6 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
   // expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for shift input, got { 3 }}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
   return %0 : tensor<13x21x3xi16>
 }
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 8352abc0c406e..544787c7a0699 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -70,3 +70,33 @@ func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: test_single_round_rescale
+func.func @test_single_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // CHECK: tosa.rescale
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_double_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = DOUBLE_ROUND requires extension [doubleround]}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = INEXACT_ROUND requires extension [inexactround]}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "INEXACT_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>
+  return %0 : tensor<13x21x3xi8>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bc13b614e3f9d..0007b81ccfb87 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -435,7 +435,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
   return %0 : tensor<1x1x1x1x13x21x3xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e1ac7d5f51d0e..b20c96363b2a2 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -96,7 +96,7 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
   %2 = "tosa.conv2d"(%arg0, %0, %1, %izp, %wzp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi4>) -> tensor<1x1x1x3xi32>
   %multiplier = "tosa.const"() {values = dense<[2026291432, 1079222024, 1693132724]> : tensor<3xi32>} : () -> tensor<3xi32>
   %shift = "tosa.const"() {values = dense<[37, 36, 37]> : tensor<3xi8>} : () -> tensor<3xi8>
-  %3 = tosa.rescale %2, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
+  %3 = tosa.rescale %2, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
   return %3 : tensor<1x1x1x3xi8>
 }
 
@@ -722,7 +722,7 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
    %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
    %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
-   %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+   %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
     return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 342c57b0dd85c..2cb32dd6edc5c 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 3dd0344e3647d..10c50019fcc89 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
 
 // -----
 func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 1d6d33b9a02c7..ee147d48a01b2 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
 // Enable all supported extensions to focus the verification of expected profile requirement errors.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
 
 // -----
 func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
@@ -24,3 +24,14 @@ func.func @test_cast_i8_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> {
   %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
   return %0 : tensor<13x21x3xi8>
 }
+
+// -----
+func.func @test_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+  // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+  %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op illegal: requires [pro_int] but not enabled in target}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 55c5c3f6bdfb6..7785ddadaa29f 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -98,7 +98,7 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
   // CHECK: tosa.rescale %arg1, %[[MULT]], %[[SHIFT]] {{.+}} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<2xi8>
   %multiplier = "tosa.const"() {values = dense<[42, 43]> : tensor<2xi16>} : () -> tensor<2xi16>
   %shift = "tosa.const"() {values = dense<[14, 15]> : tensor<2xi8>} : () -> tensor<2xi8>
-  %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, double_round = false, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
+  %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
 
   // CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index a5d1cf863bbf8..1f8c180ab665d 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -178,7 +178,7 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
       /* input_zp = */ rewriter.getI32IntegerAttr(0),
       /* output_zp = */ rewriter.getI32IntegerAttr(outputZp),
       /* scale32 = */ rewriter.getBoolAttr(true),
-      /* double_round = */ rewriter.getBoolAttr(true),
+      /* rounding_mode = */ rewriter.getStringAttr("DOUBLE_ROUND"),
       /* per_channel = */ rewriter.getBoolAttr(false),
       rewriter.getBoolAttr(inputUnsigned),
       rewriter.getBoolAttr(outputUnsigned));



More information about the Mlir-commits mailing list