[Mlir-commits] [mlir] [mlir][tosa] Add support for EXT-DOUBLEROUND and EXT-INEXACTROUND (PR #130337)
TatWai Chong
llvmlistbot at llvm.org
Fri Mar 7 12:09:05 PST 2025
https://github.com/tatwaichong created https://github.com/llvm/llvm-project/pull/130337
Adds a concept of EXT-DOUBLEROUND and EXT-INEXACTROUND
to the dialect. It also converts the "double_round" attribute on rescale
to a string type "rounding_mode" attribute with the following options:
"DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".
The validation pass has been updated to ensure "DOUBLE_ROUND"
and "INEXACT_ROUND" are only valid when their extensions are
available.
Finally, the lowerings to arith and linalg have been updated for the new
attribute; they reject "INEXACT_ROUND", for which no lowering is currently
supported.
>From 0acc7fd3815e0a00a2290c64ecd8c394b7565b11 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 15 Nov 2024 11:49:29 +0000
Subject: [PATCH] [mlir][tosa] Add support for EXT-DOUBLEROUND and
EXT-INEXACTROUND
This commit adds a concept of EXT-DOUBLEROUND and EXT-INEXACTROUND
to the dialect. It also converts the "double_round" attribute on rescale
to a string type "rounding_mode" attribute with the following options:
"DOUBLE_ROUND", "SINGLE_ROUND", "INEXACT_ROUND".
The validation pass has been updated to ensure "DOUBLE_ROUND" and
"INEXACT_ROUND" are only valid when their extensions are available.
Finally, the lowerings to arith and linalg have been updated for the new
attribute; they reject "INEXACT_ROUND", for which no lowering is currently
supported.
Co-authored-by: TatWai Chong <tatwai.chong at arm.com>
---
.../mlir/Dialect/Tosa/IR/TosaOpBase.td | 7 +++-
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 +-
.../Dialect/Tosa/IR/TosaProfileCompliance.h | 2 +
.../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 7 ++++
.../mlir/Dialect/Tosa/IR/TosaUtilOps.td | 2 +-
.../Conversion/TosaToArith/TosaToArith.cpp | 14 ++++++-
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 15 ++++---
.../TosaToLinalg/TosaToLinalgNamed.cpp | 2 +-
.../Tosa/Transforms/TosaValidation.cpp | 31 +++++++++++++-
.../TosaToArith/tosa-to-arith-invalid.mlir | 8 ++++
.../Conversion/TosaToArith/tosa-to-arith.mlir | 8 ++--
.../TosaToLinalg/tosa-to-linalg-invalid.mlir | 2 +-
.../TosaToLinalg/tosa-to-linalg-named.mlir | 2 +-
.../TosaToLinalg/tosa-to-linalg.mlir | 40 ++++++++++++-------
mlir/test/Dialect/Tosa/availability.mlir | 2 +-
mlir/test/Dialect/Tosa/canonicalize.mlir | 2 +-
mlir/test/Dialect/Tosa/invalid.mlir | 39 +++++++++---------
mlir/test/Dialect/Tosa/invalid_extension.mlir | 30 ++++++++++++++
mlir/test/Dialect/Tosa/level_check.mlir | 2 +-
mlir/test/Dialect/Tosa/ops.mlir | 4 +-
.../Dialect/Tosa/profile_all_unsupported.mlir | 2 +-
.../Tosa/profile_pro_fp_unsupported.mlir | 2 +-
.../Tosa/profile_pro_int_unsupported.mlir | 13 +++++-
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 2 +-
mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp | 2 +-
25 files changed, 180 insertions(+), 62 deletions(-)
create mode 100644 mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f2328003e49c5..db725dbd5e1bf 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -226,6 +226,8 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
// FFT : Fast Fourier Transform operations.
// VARIABLE : Stateful variable operations.
// CONTROLFLOW : Control Flow operations.
+// DOUBLEROUND : Adds double rounding support to the RESCALE operator.
+// INEXACTROUND : Adds inexact rounding support to the RESCALE operator.
//===----------------------------------------------------------------------===//
def Tosa_NONE : I32EnumAttrCase<"none", 0>;
@@ -241,11 +243,14 @@ def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>;
def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>;
def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>;
def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
+def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
+def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3,
- Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE
+ Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW,
+ Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND, Tosa_EXT_NONE
]>;
def Tosa_ExtensionArrayAttr
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5340ce52d73fc..3f87e299cbbdd 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2310,7 +2310,7 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
I32Attr:$input_zp,
I32Attr:$output_zp,
BoolAttr:$scale32,
- BoolAttr:$double_round,
+ Tosa_RoundingTypeAttr:$rounding_mode,
BoolAttr:$per_channel,
BoolAttr: $input_unsigned,
BoolAttr: $output_unsigned
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
index 969d06afc70d6..88f454f63e6f9 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
@@ -136,6 +136,8 @@ class TosaProfileCompliance {
switch (ext) {
case Extension::int16:
case Extension::int4:
+ case Extension::doubleround:
+ case Extension::inexactround:
return {Profile::pro_int};
case Extension::bf16:
case Extension::fp8e4m3:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 08c0e02139b0c..0038d8c386ca7 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -247,6 +247,13 @@ def Tosa_NanPropagationAttr : StringBasedAttr<
"::llvm::cast<StringAttr>($_self).getValue() == \"IGNORE\"">,
"Supported NaN propagation strategies">;
+// Rounding mode for tosa.rescale
+def Tosa_RoundingTypeAttr : StringBasedAttr<
+ CPred<"::llvm::cast<StringAttr>($_self).getValue() == \"SINGLE_ROUND\" || " #
+ "::llvm::cast<StringAttr>($_self).getValue() == \"INEXACT_ROUND\" || " #
+ "::llvm::cast<StringAttr>($_self).getValue() == \"DOUBLE_ROUND\"">,
+ "Supported rounding modes">;
+
def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
// Tensor to buffer types.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index 8756cb9e5de3a..8a27e5ba39331 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -44,7 +44,7 @@ def Tosa_ApplyScaleOp :
Tosa_IntLike:$value,
Tosa_IntLike:$multiplier,
Tosa_Int8Like:$shift,
- BoolAttr:$double_round
+ Tosa_RoundingTypeAttr:$rounding_mode
);
let results = (outs
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 5c84b4063da2e..9dea12355a519 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -65,6 +65,11 @@ class ApplyScaleGenericOpConverter
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
+ StringRef roundingMode = op.getRoundingMode();
+ if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ return failure();
+ }
+
Location loc = op.getLoc();
Value value = op.getValue();
Value multiplier32 = op.getMultiplier();
@@ -96,7 +101,7 @@ class ApplyScaleGenericOpConverter
multiply64 = rewriter.create<arith::AddIOp>(loc, multiply64, round);
// Apply double rounding if necessary.
- if (op.getDoubleRound()) {
+ if (op.getRoundingMode() == "DOUBLE_ROUND") {
int64_t roundInt = 1 << 30;
Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter);
Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter);
@@ -125,6 +130,11 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
LogicalResult matchAndRewrite(tosa::ApplyScaleOp op,
PatternRewriter &rewriter) const final {
+ StringRef roundingMode = op.getRoundingMode();
+ if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") {
+ return failure();
+ }
+
Location loc = op.getLoc();
Type resultTy = op.getType();
@@ -170,7 +180,7 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
rewriter.create<arith::SelectOp>(loc, shiftOver32, shiftHighR, zero32);
// Conditionally perform our double round.
- if (op.getDoubleRound()) {
+ if (op.getRoundingMode() == "DOUBLE_ROUND") {
Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter);
Value valuePositive = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, value32, zero32);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..b59e55302a60c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -170,7 +170,7 @@ static Value createLinalgBodyCalculationForElementwiseOp(
auto result = rewriter.create<tosa::ApplyScaleOp>(
loc, rewriter.getI32Type(), a, b, shiftConst,
- rewriter.getBoolAttr(false));
+ rewriter.getStringAttr("SINGLE_ROUND"));
if (elementTy.isInteger(32))
return result;
@@ -1374,7 +1374,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned rank = inputTy.getRank();
// This is an illegal configuration. terminate and log an error
- if (op.getDoubleRound() && !op.getScale32())
+ if (op.getRoundingMode() == "INEXACT_ROUND")
+ return rewriter.notifyMatchFailure(
+ op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not currently supported");
+ if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires scale32 for double_round to be true");
@@ -1418,9 +1421,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// Double round only occurs if shift is greater than 31, check that this
// is ever true.
+
bool doubleRound =
- op.getDoubleRound() &&
+ op.getRoundingMode() == "DOUBLE_ROUND" &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ StringAttr roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND") :
+ rewriter.getStringAttr("SINGLE_ROUND");
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1515,8 +1521,7 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
value = nestedBuilder.create<tosa::ApplyScaleOp>(
- loc, nestedBuilder.getI32Type(), value, multiplier, shift,
- nestedBuilder.getBoolAttr(doubleRound));
+ loc, nestedBuilder.getI32Type(), value, multiplier, shift, roundingMode);
// Move to the new zero-point.
value =
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..2dd3d2fb3325d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -1005,7 +1005,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
rewriter
.create<tosa::ApplyScaleOp>(loc, rewriter.getI32Type(),
poolVal, multiplier, shift,
- rewriter.getBoolAttr(false))
+ rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();
// If we have quantization information we need to apply output
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b8604ef40cc93..70c4cd0a526cd 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -104,6 +104,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
LogicalResult applyLevelCheck(Operation *op);
+ LogicalResult applyAttributeCheck(Operation *op);
// check variable read/write data types against variable declarations
LogicalResult applyVariableCheck(Operation *op);
@@ -386,6 +387,23 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
return true;
}
+ bool attributeCheckRescale(Operation *op) {
+ if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) {
+ if (rescale.getRoundingMode() == "DOUBLE_ROUND" &&
+ !targetEnv.allows(Extension::doubleround)) {
+ op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND "
+ << "requires extension [doubleround]";
+ return false;
+ } else if (rescale.getRoundingMode() == "INEXACT_ROUND" &&
+ !targetEnv.allows(Extension::inexactround)) {
+ op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND "
+ << "requires extension [inexactround]";
+ return false;
+ }
+ }
+ return true;
+ }
+
// configure profile and level values from pass options profileName and
// levelName
void configLevelAndProfile() {
@@ -415,7 +433,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
} else {
llvm::errs() << "unknown TOSA extension name passed in: " << ext
<< ", supported extension are int16, int4, bf16, "
- << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n";
+ << "fp8e4m3, fp8e5m2, fft, variable, controlflow, "
+ << "doubleround and inexactround\n";
return signalPassFailure();
}
}
@@ -642,6 +661,12 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
return success();
}
+LogicalResult TosaValidation::applyAttributeCheck(Operation *op) {
+ if (!attributeCheckRescale(op))
+ return failure();
+ return success();
+}
+
inline bool CompatibleTypes(const mlir::Type &type,
const mlir::Type &declaredType) {
// for now, simply use type equality comparison
@@ -936,6 +961,10 @@ void TosaValidation::runOnOperation() {
if (failed(applyLevelCheck(op)))
signalPassFailure();
+ // check additional attribute restrictions
+ if (failed(applyAttributeCheck(op)))
+ signalPassFailure();
+
// do variable type checks
if (failed(applyVariableCheck(op)))
signalPassFailure();
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
new file mode 100644
index 0000000000000..4b324955439aa
--- /dev/null
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith-invalid.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics
+
+// CHECK-LABEL: @apply_scale_unsupported_inexact_round
+func.func @apply_scale_unsupported_inexact_round(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // expected-error@+1 {{failed to legalize operation 'tosa.apply_scale'}}
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "INEXACT_ROUND"} : (i64, i32, i8) -> i32
+ return %res : i32
+}
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index 14f811727c456..db68ca40879f4 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -67,7 +67,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
// CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]
// CHECK: return %[[RESULT]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
return %res : i32
}
@@ -77,7 +77,7 @@ func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
// SCALE: tosa.apply_scale
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
// CHECK-NOT: "tosa.apply_scale"
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
return %res : vector<4xi32>
}
@@ -115,7 +115,7 @@ func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i48, i32, i8) -> i32
return %res : i32
}
@@ -152,6 +152,6 @@ func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
// CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
// CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
// CHECK: return %[[TRUNC]]
- %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+ %res = tosa.apply_scale %arg0, %arg1, %arg2 {rounding_mode = "DOUBLE_ROUND"} : (i64, i32, i8) -> i32
return %res : i32
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..54c6ed994e947 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -36,7 +36,7 @@ func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32,
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..a89da9a2b9fed 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -411,7 +411,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
// CHECK: %[[TRUNC_SHIFT:.+]] = arith.trunci %[[SUB]]
// CHECK: %[[C30:.+]] = arith.constant 30
// CHECK: %[[SHIFT:.+]] = arith.addi %[[TRUNC_SHIFT]], %[[C30]] : i8
- // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {double_round = false}
+ // CHECK: %[[SCALED:.+]] = tosa.apply_scale %[[IN]], %[[TRUNC_MUL]], %[[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
// Perform the normalization.
// CHECK: %[[CMIN:.+]] = arith.constant -128
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..596069ad7f53d 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1141,7 +1141,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1151,7 +1151,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1172,7 +1172,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
@@ -1182,7 +1182,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1201,13 +1201,13 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
return
}
@@ -1227,7 +1227,7 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
%multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
return
}
@@ -1247,7 +1247,7 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {double_round = false}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1257,7 +1257,7 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
return
}
@@ -1279,7 +1279,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
// CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
// CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C243]]
- // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]] {double_round = false}
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C252]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
// CHECK-DAG: [[CMAX:%.+]] = arith.constant 127
@@ -1289,7 +1289,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
// CHECK-DAG: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<[42, 43, 44]> : tensor<3xi16> } : () -> tensor<3xi16>
%shift = "tosa.const"() {values = dense<[14, 15, 64]> : tensor<3xi8> } : () -> tensor<3xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<3xi8>, tensor<3xi16>, tensor<3xi8>) -> tensor<3xi8>
// CHECK: return [[GENERIC]]
return %0 : tensor<3xi8>
@@ -1301,10 +1301,10 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
// CHECK: linalg.generic
// CHECK: tosa.apply_scale
- // CHECK-SAME: {double_round = true}
+ // CHECK-SAME: {rounding_mode = "DOUBLE_ROUND"}
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
@@ -1312,10 +1312,20 @@ func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
// CHECK: linalg.generic
// CHECK: tosa.apply_scale
- // CHECK-SAME: {double_round = false}
+ // CHECK-SAME: {rounding_mode = "SINGLE_ROUND"}
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<33> : tensor<1xi8> } : () -> tensor<1xi8>
+  // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = true, rounding_mode = "INEXACT_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2xi8>
return %0 : tensor<2xi8>
}
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 1952ad79392c7..6ec0d548a2dee 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -614,7 +614,7 @@ func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
// CHECK: profiles: [ [pro_int] ]
// CHECK: extensions: [ [int16] ]
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 4242f68609634..bf7fbde01447e 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -905,7 +905,7 @@ func.func @reshape_quant_nofold() -> tensor<1x1x1x1xi32> {
%1 = tosa.reshape %0, %cst0 : (tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, !tosa.shape<4>) -> tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
- %2 = tosa.rescale %1, %multiplier, %shift {double_round = true, input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
+ %2 = tosa.rescale %1, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = -128 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1xi32>
return %2 : tensor<1x1x1x1xi32>
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 05700ca3765e4..51ee21fea23bd 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
// validation flow.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
func.func @test_const() -> tensor<1xf32> {
// expected-error@+1{{'tosa.const' op expected same attr/result element types}}
@@ -989,6 +989,7 @@ func.func @test_mismatch_in_out_shape_clz(%arg0: tensor<13x21x3xi32>) -> tensor<
}
// -----
+
// CHECK-LABEL: test_mismatch_in_out_data_type_cos
func.func @test_mismatch_in_out_data_type_cos(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf16> {
// expected-error@+1 {{'tosa.cos' op requires the same element type for all operands and results}}
@@ -1438,7 +1439,7 @@ func.func @test_rescale_invalid_input_type(%arg0: tensor<13x21x3xf32>) -> tensor
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect input to have integer element type, got 'f32'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xf32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1448,7 +1449,7 @@ func.func @test_rescale_invalid_output_type(%arg0: tensor<13x21x3xi32>) -> tenso
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect output to have integer element type, got 'f32'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1458,7 +1459,7 @@ func.func @test_rescale_invalid_multiplier_type(%arg0: tensor<13x21x3xi32>) -> t
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi48> } : () -> tensor<1xi48>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1xi48>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi48>, tensor<1xi16>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1468,7 +1469,7 @@ func.func @test_rescale_invalid_shift_type(%arg0: tensor<13x21x3xi32>) -> tensor
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi16> } : () -> tensor<1xi16>
// expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1xi16>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi16>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
@@ -1478,7 +1479,7 @@ func.func @test_rescale_invalid_input_zp_i32(%arg0: tensor<13x21x3xi32>) -> tens
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1488,7 +1489,7 @@ func.func @test_rescale_invalid_input_zp_s16(%arg0: tensor<13x21x3xi16>) -> tens
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect input_zp of 0, got 1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1498,7 +1499,7 @@ func.func @test_rescale_invalid_input_zp_u16(%arg0: tensor<13x21x3xi16>) -> tens
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect input_zp of 0 or 32768 for unsigned int16 input, got 1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 1 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1509,7 +1510,7 @@ func.func @test_rescale_invalid_output_zp_i32(%arg0: tensor<13x21x3xi32>) -> ten
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
}
@@ -1519,7 +1520,7 @@ func.func @test_rescale_invalid_output_zp_s16(%arg0: tensor<13x21x3xi16>) -> ten
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0, got -1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1529,7 +1530,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1539,7 +1540,7 @@ func.func @test_rescale_invalid_multiplier_i16(%arg0: tensor<13x21x3xi16>) -> te
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect i32 element type for multiplier for scale32=true, got 'i16'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1549,7 +1550,7 @@ func.func @test_rescale_invalid_multiplier_i32(%arg0: tensor<13x21x3xi16>) -> te
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect i16 element type for multiplier for scale32=false, got 'i32'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1559,7 +1560,7 @@ func.func @test_rescale_invalid_multiplier_rank(%arg0: tensor<13x21x3xi16>) -> t
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1x1xi32> } : () -> tensor<1x1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op operand #1 must be 1D tensor of 16-bit signless integer or 32-bit signless integer values, but got 'tensor<1x1xi32>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1569,7 +1570,7 @@ func.func @test_rescale_invalid_shift_rank(%arg0: tensor<13x21x3xi16>) -> tensor
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1x1xi8> } : () -> tensor<1x1xi8>
// expected-error@+1 {{'tosa.rescale' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1x1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1579,7 +1580,7 @@ func.func @test_rescale_invalid_perchannel_multiplier_shape(%arg0: tensor<13x21x
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for multiplier input, got { 1 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1589,7 +1590,7 @@ func.func @test_rescale_invalid_non_perchannel_multiplier_shape(%arg0: tensor<13
%multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for multiplier input, got { 3 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1599,7 +1600,7 @@ func.func @test_rescale_invalid_perchannel_shift_shape(%arg0: tensor<13x21x3xi16
%multiplier = "tosa.const"() {values = dense<19689> : tensor<3xi32> } : () -> tensor<3xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 3 } for shift input, got { 1 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<3xi32>, tensor<1xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
@@ -1609,6 +1610,6 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<3xi8> } : () -> tensor<3xi8>
// expected-error@+1 {{'tosa.rescale' op expect shape of { 1 } for shift input, got { 3 }}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<3xi8>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir
index 8352abc0c406e..544787c7a0699 100644
--- a/mlir/test/Dialect/Tosa/invalid_extension.mlir
+++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir
@@ -70,3 +70,33 @@ func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
return
}
+// -----
+
+// CHECK-LABEL: test_single_round_rescale
+func.func @test_single_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // CHECK: tosa.rescale
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_double_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = DOUBLE_ROUND requires extension [doubleround]}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ return %0 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_inexact_round_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op failed attribute check: rounding_mode = INEXACT_ROUND requires extension [inexactround]}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "INEXACT_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi8>
+ return %0 : tensor<13x21x3xi8>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index bc13b614e3f9d..0007b81ccfb87 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -435,7 +435,7 @@ func.func @test_rescale_rank_invalid(%arg0: tensor<1x1x1x1x13x21x3xi8>) -> tenso
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.rescale' op failed level check: operand rank(shape) <= MAX_RANK}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x1x13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x1x1x1x13x21x3xi8>
return %0 : tensor<1x1x1x1x13x21x3xi8>
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e1ac7d5f51d0e..b20c96363b2a2 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -96,7 +96,7 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
%2 = "tosa.conv2d"(%arg0, %0, %1, %izp, %wzp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi4>) -> tensor<1x1x1x3xi32>
%multiplier = "tosa.const"() {values = dense<[2026291432, 1079222024, 1693132724]> : tensor<3xi32>} : () -> tensor<3xi32>
%shift = "tosa.const"() {values = dense<[37, 36, 37]> : tensor<3xi8>} : () -> tensor<3xi8>
- %3 = tosa.rescale %2, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
+ %3 = tosa.rescale %2, %multiplier, %shift {rounding_mode = "DOUBLE_ROUND", input_zp = 0 : i32, output_zp = 27 : i32, per_channel = true, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x1x1x3xi32>, tensor<3xi32>, tensor<3xi8>) -> tensor<1x1x1x3xi8>
return %3 : tensor<1x1x1x3xi8>
}
@@ -722,7 +722,7 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
index 342c57b0dd85c..2cb32dd6edc5c 100644
--- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround strict-op-spec-alignment"
// -----
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
index 3dd0344e3647d..10c50019fcc89 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
// -----
func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x4x8xf32> {
diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
index 1d6d33b9a02c7..ee147d48a01b2 100644
--- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
+++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
@@ -2,7 +2,7 @@
// Enable all supported extensions to focus the verification of expected profile requirement errors.
//--------------------------------------------------------------------------------------------------
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround strict-op-spec-alignment"
// -----
func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {
@@ -24,3 +24,14 @@ func.func @test_cast_i8_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi8> {
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xi8>
return %0 : tensor<13x21x3xi8>
}
+
+// -----
+func.func @test_rescale(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi32> {
+ // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+ %multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32>} : () -> tensor<1xi32>
+ // expected-error@+1 {{'tosa.const' op illegal: requires [pro_int] but not enabled in target}}
+ %shift = "tosa.const"() {values = dense<30> : tensor<1xi8>} : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op illegal: requires [pro_int] but not enabled in target}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32>
+ return %0 : tensor<13x21x3xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 55c5c3f6bdfb6..7785ddadaa29f 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -98,7 +98,7 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<2xi8>) -> () {
// CHECK: tosa.rescale %arg1, %[[MULT]], %[[SHIFT]] {{.+}} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<2xi8>
%multiplier = "tosa.const"() {values = dense<[42, 43]> : tensor<2xi16>} : () -> tensor<2xi16>
%shift = "tosa.const"() {values = dense<[14, 15]> : tensor<2xi8>} : () -> tensor<2xi8>
- %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, double_round = false, per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
+ %6 = tosa.rescale %arg1, %multiplier, %shift {input_zp = 243 : i32, output_zp = 252 : i32, scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = true, output_unsigned = true} : (tensor<2xi8>, tensor<2xi16>, tensor<2xi8>) -> tensor<*xi8>
// CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
%7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index a5d1cf863bbf8..1f8c180ab665d 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -178,7 +178,7 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
/* input_zp = */ rewriter.getI32IntegerAttr(0),
/* output_zp = */ rewriter.getI32IntegerAttr(outputZp),
/* scale32 = */ rewriter.getBoolAttr(true),
- /* double_round = */ rewriter.getBoolAttr(true),
+ /* rounding_mode = */ rewriter.getStringAttr("DOUBLE_ROUND"),
/* per_channel = */ rewriter.getBoolAttr(false),
rewriter.getBoolAttr(inputUnsigned),
rewriter.getBoolAttr(outputUnsigned));
More information about the Mlir-commits
mailing list