[Mlir-commits] [mlir] [mlir][tosa] Enhance error_if and verify checks for RESCALE Op (PR #137021)
Peng Sun
llvmlistbot at llvm.org
Wed Apr 23 10:18:06 PDT 2025
https://github.com/psunn updated https://github.com/llvm/llvm-project/pull/137021
>From e0d037fb53e0b6d3851a90d23c3059b2562aa4de Mon Sep 17 00:00:00 2001
From: Peng Sun <peng.sun at arm.com>
Date: Tue, 22 Apr 2025 21:12:30 +0000
Subject: [PATCH] [mlir][tosa] Enhance error_if and verify checks for RESCALE
Op
* add verifier for rank-0 input with per-channel
* add checkErrorIfRescale to the tosa validation pass to align with
TOSA v1.0
* add LIT tests
Change-Id: Ia07e8c2ee66d8ee4113bea5ad9fa859b5986b009
Signed-off-by: Peng Sun <peng.sun at arm.com>
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 +
.../Tosa/Transforms/TosaValidation.cpp | 82 ++++++++++++-
mlir/test/Dialect/Tosa/error_if_check.mlir | 108 ++++++++++++++++++
mlir/test/Dialect/Tosa/invalid.mlir | 12 ++
mlir/test/Dialect/Tosa/verifier.mlir | 13 +++
5 files changed, 220 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1ab4ce7d4558b..f1bed1241f971 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3186,6 +3186,12 @@ LogicalResult RescaleOp::verify() {
// otherwise numChannel is dimension in input shape's last axis
int64_t numChannels = 1;
if (getPerChannel()) {
+ if (inputType.getRank() < 1) {
+ emitOpError("requires input to be at least rank 1 when per_channel is "
+ "true, but got rank ")
+ << inputType.getRank();
+ return failure();
+ }
numChannels = inputType.getDimSize(inputType.getRank() - 1);
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index baa202833e285..06c2036923dfe 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1033,8 +1033,88 @@ bool checkErrorIfTable(Operation *op) {
return true;
}
+// Checks the TOSA ERROR_IF conditions for RESCALE on `op`. Returns true when
+// `op` is not a tosa.rescale, when its input/output are not integer shaped
+// types, or when no condition is violated; otherwise emits an op error and
+// returns false. Diagnostics carry no trailing period (LLVM message style).
+bool checkErrorIfRescale(Operation *op) {
+  auto rescale = dyn_cast<tosa::RescaleOp>(op);
+  if (!rescale)
+    return true;
+
+  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
+  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
+  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
+      !outputType.getElementType().isInteger())
+    return true;
+
+  const unsigned inWidth = inputType.getElementType().getIntOrFloatBitWidth();
+  const unsigned outWidth = outputType.getElementType().getIntOrFloatBitWidth();
+  const bool inputUnsigned = rescale.getInputUnsigned();
+  const bool outputUnsigned = rescale.getOutputUnsigned();
+  const bool scale32 = rescale.getScale32();
+  const auto roundingMode = rescale.getRoundingMode();
+
+  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
+  if (scale32 && inWidth == 48) {
+    op->emitOpError() << "scale32 is not allowed with 48-bit input";
+    return false;
+  }
+
+  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
+  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+    op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true";
+    return false;
+  }
+
+  // ERROR_IF(input_unsigned && output_unsigned)
+  if (inputUnsigned && outputUnsigned) {
+    op->emitOpError() << "input and output cannot be both unsigned";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
+  if (outWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 output type is not allowed with unsigned input";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
+  if (inWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 input type is not allowed with unsigned output";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
+  if (inWidth == 48 && outputUnsigned) {
+    op->emitOpError() << "i48 input type is not allowed with unsigned output";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
+  if (inWidth == 48 && inputUnsigned) {
+    op->emitOpError() << "i48 input type cannot be unsigned";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
+  if (inWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 input type cannot be unsigned";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
+  if (outWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 output type cannot be unsigned";
+    return false;
+  }
+
+  return true;
+}
+
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
- if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
+ if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
+ !checkErrorIfTable(op) || !checkErrorIfRescale(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 65a69be91e0c8..c6a173c92ff9a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -129,3 +129,111 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
%0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
return %0 : tensor<2x64xi8>
}
+
+// -----
+// CHECK-LABEL: test_error_input_zp_not_allowed
+func.func @test_error_input_zp_not_allowed(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_scale32_with_i48
+func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_input_output_unsigned
+func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_output_unsigned_input
+func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_input_unsigned_output
+func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_input_unsigned_output
+func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_unsigned_input
+func.func @test_error_i48_unsigned_input(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_input
+func.func @test_error_i32_unsigned_input(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_output
+func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..fe4cc49e89c0d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1669,6 +1669,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
return %0 : tensor<13x21x3xi16>
}
+// -----
+// CHECK-LABEL: test_error_double_round_without_scale32
+func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
// -----
// CHECK-LABEL: test_matmul_a_zp_same_element_type
func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index fb8726cba1853..a42cf03a0a5cb 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -319,3 +319,16 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
: (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
return %0 : tensor<1x4x8x19x34xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_error_scalar_input_with_per_channel
+func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
+  %multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
+  return %0 : tensor<i16>
+}
More information about the Mlir-commits
mailing list