[Mlir-commits] [mlir] [mlir][tosa] Add support for EXT-DOUBLEROUND and EXT-INEXACTROUND (PR #130337)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 7 12:09:36 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: TatWai Chong (tatwaichong)

<details>
<summary>Changes</summary>

Adds the EXT-DOUBLEROUND and EXT-INEXACTROUND extensions
to the dialect. It also replaces the boolean "double_round" attribute on
rescale with a string-typed "rounding_mode" attribute that accepts one of:
"DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".

The validation pass has been updated to ensure that "DOUBLE_ROUND"
and "INEXACT_ROUND" are accepted only when their corresponding
extensions are enabled.

Finally, the lowerings to arith and linalg have been updated to reject
"INEXACT_ROUND", which they do not currently support.

---

Patch is 63.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130337.diff


25 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+6-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h (+2) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td (+7) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td (+1-1) 
- (modified) mlir/lib/Conversion/TosaToArith/TosaToArith.cpp (+12-2) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+10-5) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+30-1) 
- (added) mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir (+8) 
- (modified) mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir (+4-4) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+1-1) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+1-1) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+25-15) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+20-19) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+30) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+12-1) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+1-1) 
- (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f2328003e49c5..db725dbd5e1bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -226,6 +226,8 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
 // FFT          : Fast Fourier Transform operations.
 // VARIABLE     : Stateful variable operations.
 // CONTROLFLOW  : Control Flow operations.
+// DOUBLEROUND  : Adds double rounding support to the RESCALE operator.
+// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
 //===----------------------------------------------------------------------===//
 
 def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -241,11 +243,14 @@ def Tosa_EXT_FP8E5M2      : I32EnumAttrCase<"fp8e5m2", 5>;
 def Tosa_EXT_FFT          : I32EnumAttrCase<"fft", 6>;
 def Tosa_EXT_VARIABLE     : I32EnumAttrCase<"variable", 7>;
 def Tosa_EXT_CONTROLFLOW  : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_DOUBLEROUND  : I32EnumAttrCase<"doubleround", 9>;
+def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
 
 def Tosa_ExtensionAttr
     : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
       Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
-      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
+      Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
+      Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
     ]>;
 
 def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5340ce52d73fc..3f87e299cbbdd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2310,7 +2310,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
     I32Attr:$input_zp,
     I32Attr:$output_zp,
     BoolAttr:$scale32,
-    BoolAttr:$double_round,
+    Tosa_RoundingTypeAttr:$rounding_mode,
     BoolAttr:$per_channel,
     BoolAttr: $input_unsigned,
     BoolAttr: $output_unsigned
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 969d06afc70d6..88f454f63e6f9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -136,6 +136,8 @@ class TosaProfileCompliance {
     switch (ext) {
     case Extension::int16:
     case Extension::int4:
+    case Extension::doubleround:
+    case Extension::inexactround:
       return {Profile::pro_int};
     case Extension::bf16:
     case Extension::fp8e4m3:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 08c0e02139b0c..0038d8c386ca7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -247,6 +247,13 @@ def Tosa_NanPropagationAttr : StringBasedAttr<
           "::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
     "Supported NaN propagation strategies">;
 
+// Rounding mode for tosa.rescale
+def Tosa_RoundingTypeAttr : StringBasedAttr<
+    CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\"  || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
+          "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
+    "Supported rounding modes">;
+
 def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
 
 // Tensor to buffer types.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 8756cb9e5de3a..8a27e5ba39331 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
     Tosa_IntLike:$value,
     Tosa_IntLike:$multiplier,
     Tosa_Int8Like:$shift,
-    BoolAttr:$double_round
+    Tosa_RoundingTypeAttr:$rounding_mode
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 5c84b4063da2e..9dea12355a519 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -65,6 +65,11 @@ class ApplyScaleGenericOpConverter
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
+    StringRef roundingMode = op.getRoundingMode();
+    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+      return failure();
+    }
+
     Location loc = op.getLoc();
     Value value = op.getValue();
     Value multiplier32 = op.getMultiplier();
@@ -96,7 +101,7 @@ class ApplyScaleGenericOpConverter
     multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
 
     // Apply double rounding if necessary.
-    if (op.getDoubleRound()) {
+    if (op.getRoundingMode() == "DOUBLE_ROUND") {
       int64_t roundInt = 1 << 30;
       Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
       Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -125,6 +130,11 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
   LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
                                 PatternRewriter &rewriter) const final {
+    StringRef roundingMode = op.getRoundingMode();
+    if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+      return failure();
+    }
+
     Location loc = op.getLoc();
 
     Type resultTy = op.getType();
@@ -170,7 +180,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
         rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
 
     // Conditionally perform our double round.
-    if (op.getDoubleRound()) {
+    if (op.getRoundingMode() == "DOUBLE_ROUND") {
       Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
       Value valuePositive = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..b59e55302a60c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -170,7 +170,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
         auto result = rewriter.create<tosa::ApplyScaleOp>(
             loc, rewriter.getI32Type(), a, b, shiftConst,
-            rewriter.getBoolAttr(false));
+            rewriter.getStringAttr("SINGLE_ROUND"));
 
         if (elementTy.isInteger(32))
           return result;
@@ -1374,7 +1374,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     unsigned rank = inputTy.getRank();
 
     // This is an illegal configuration. terminate and log an error
-    if (op.getDoubleRound() && !op.getScale32())
+    if (op.getRoundingMode() == "INEXACT_ROUND")
+      return rewriter.notifyMatchFailure(
+          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not currently supported");
+    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
       return rewriter.notifyMatchFailure(
           op, "tosa.rescale requires scale32 for double_round to be true");
 
@@ -1418,9 +1421,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 
     // Double round only occurs if shift is greater than 31, check that this
     // is ever true.
+
     bool doubleRound =
-        op.getDoubleRound() &&
+        op.getRoundingMode() == "DOUBLE_ROUND" &&
         llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    StringAttr roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND") :
+        rewriter.getStringAttr("SINGLE_ROUND");
 
     SmallVector<AffineMap> indexingMaps = {
         rewriter.getMultiDimIdentityMap(rank)};
@@ -1515,8 +1521,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
 
           value = nestedBuilder.create<tosa::ApplyScaleOp>(
-              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
-              nestedBuilder.getBoolAttr(doubleRound));
+              loc, nestedBuilder.getI32Type(), value, multiplier, shift, roundingMode);
 
           // Move to the new zero-point.
           value =
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..2dd3d2fb3325d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1005,7 +1005,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
                 rewriter
                     .create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
                                                 poolVal, multiplier, shift,
-                                                rewriter.getBoolAttr(false))
+                                                rewriter.getStringAttr("SINGLE_ROUND"))
                     .getResult();
 
             // If we have quantization information we need to apply output
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b8604ef40cc93..70c4cd0a526cd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   }
 
   LogicalResult applyLevelCheck(Operation *op);
+  LogicalResult applyAttributeCheck(Operation *op);
 
   // check variable read/write data types against variable declarations
   LogicalResult applyVariableCheck(Operation *op);
@@ -386,6 +387,23 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
     return true;
   }
 
+  bool attributeCheckRescale(Operation *op) {
+    if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
+      if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+          !targetEnv.allows(Extension::doubleround)) {
+        op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND "
+                          << "requires extension [doubleround]";
+        return false;
+      } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
+          !targetEnv.allows(Extension::inexactround)) {
+        op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND "
+                          << "requires extension [inexactround]";
+        return false;
+      }
+    }
+    return true;
+  }
+
   // configure profile and level values from pass options profileName and
   // levelName
   void configLevelAndProfile() {
@@ -415,7 +433,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
         } else {
           llvm::errs() << "unknown TOSA extension name passed in: " << ext
                        << ", supported extension are int16, int4, bf16, "
-                       << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
+                       << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
+                       << "doubleround and inexactround\n";
           return signalPassFailure();
         }
       }
@@ -642,6 +661,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
   return success();
 }
 
+LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
+  if (!attributeCheckRescale(op))
+    return failure();
+  return success();
+}
+
 inline bool CompatibleTypes(const mlir::Type &type,
                             const mlir::Type &declaredType) {
   // for now, simply use type equality comparison
@@ -936,6 +961,10 @@ void TosaValidation::runOnOperation() {
     if (failed(applyLevelCheck(op)))
       signalPassFailure();
 
+    // check additional attribute restrictions
+    if (failed(applyAttributeCheck(op)))
+      signalPassFailure();
+
     // do variable type checks
     if (failed(applyVariableCheck(op)))
       signalPassFailure();
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
new file mode 100644
index 0000000000000..4b324955439aa
--- /dev/null
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics
+
+// CHECK-LABEL: @apply_scale_unsupported_inexact_round
+func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // expected-error at +1 {{failed to legalize operation 'tosa.apply_scale'}}
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
+  return %res : i32
+}
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 14f811727c456..db68ca40879f4 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
   // CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
   // CHECK: return %[[RESULT]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
   return %res : i32
 }
 
@@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
 // SCALE: tosa.apply_scale
 func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
   // CHECK-NOT: "tosa.apply_scale"
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
   return %res : vector<4xi32>
 }
 
@@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
   // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
   // CHECK: return %[[TRUNC]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
   return %res : i32
 }
 
@@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
   // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
   // CHECK: return %[[TRUNC]]
-  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
   return %res : i32
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..54c6ed994e947 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -36,7 +36,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
   %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
   %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
   // expected-error at +1 {{failed to legalize operation 'tosa.rescale'}}
-  %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..a89da9a2b9fed 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -411,7 +411,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
   // CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
   // CHECK: %[[C30:.+]] = arith.constant 30
   // CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
-  // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {double_round = false}
+  // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
 
   // Perform the normalization.
   // CHECK: %[[CMIN:.+]] = arith.constant -128
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..596069ad7f53d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1141,7 +1141,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1151,7 +1151,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return
   return
@@ -1172,7 +1172,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
-  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1182,7 +1182,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : t...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/130337


More information about the Mlir-commits mailing list