[Mlir-commits] [mlir] [mlir][tosa] Enhance error_if and verify checks for RESCALE Op (PR #137021)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 23 10:10:56 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peng Sun (psunn)

<details>
<summary>Changes</summary>

  * add a verifier check that rejects rank-0 input when per_channel is true
  * add checkErrorIfRescale to the TOSA validation pass to align with TOSA v1.0
  * add LIT tests

---
Full diff: https://github.com/llvm/llvm-project/pull/137021.diff


5 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+6) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+82-1) 
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+108) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+12) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+13) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1ab4ce7d4558b..f1bed1241f971 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3186,6 +3186,12 @@ LogicalResult RescaleOp::verify() {
   // otherwise numChannel is dimension in input shape's last axis
   int64_t numChannels = 1;
   if (getPerChannel()) {
+    // Guard the last-axis lookup below: a rank-0 input has no axis to read
+    // the channel count from.
+    if (inputType.getRank() < 1)
+      return emitOpError("requires input to be at least rank 1 when "
+                         "per_channel is true, but got rank ")
+             << inputType.getRank();
     numChannels = inputType.getDimSize(inputType.getRank() - 1);
   }
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index baa202833e285..fa337f350197c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1033,8 +1033,89 @@ bool checkErrorIfTable(Operation *op) {
   return true;
 }
 
+/// Checks the TOSA v1.0 ERROR_IF conditions of RESCALE that depend only on
+/// the op's attributes and integer element types. Emits an op error and
+/// returns false on the first violation; returns true otherwise, including
+/// for ops that are not RESCALE or whose tensors are not integer-typed.
+bool checkErrorIfRescale(Operation *op) {
+  auto rescale = dyn_cast<tosa::RescaleOp>(op);
+  if (!rescale)
+    return true;
+
+  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
+  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
+  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
+      !outputType.getElementType().isInteger())
+    return true;
+
+  unsigned inWidth = inputType.getElementType().getIntOrFloatBitWidth();
+  unsigned outWidth = outputType.getElementType().getIntOrFloatBitWidth();
+
+  bool inputUnsigned = rescale.getInputUnsigned();
+  bool outputUnsigned = rescale.getOutputUnsigned();
+  bool scale32 = rescale.getScale32();
+  auto roundingMode = rescale.getRoundingMode();
+
+  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
+  if (scale32 && inWidth == 48) {
+    op->emitOpError() << "scale32 is not allowed with 48-bit input.";
+    return false;
+  }
+
+  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
+  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+    op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
+    return false;
+  }
+
+  // ERROR_IF(input_unsigned && output_unsigned)
+  if (inputUnsigned && outputUnsigned) {
+    op->emitOpError() << "input and output cannot be both unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
+  if (outWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 output type is not allowed with unsigned input.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
+  if (inWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 input type is not allowed with unsigned output.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
+  if (inWidth == 48 && outputUnsigned) {
+    op->emitOpError() << "i48 input type is not allowed with unsigned output.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i48_t>() && input_unsigned)
+  if (inWidth == 48 && inputUnsigned) {
+    op->emitOpError() << "i48 input type cannot be unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i32_t>() && input_unsigned)
+  if (inWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 input type cannot be unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t,i32_t>() && output_unsigned)
+  if (outWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 output type cannot be unsigned.";
+    return false;
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
+  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
+      !checkErrorIfTable(op) || !checkErrorIfRescale(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 65a69be91e0c8..c6a173c92ff9a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -129,3 +129,99 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
     %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
     return %0 : tensor<2x64xi8>
 }
+
+// -----
+// CHECK-LABEL: test_error_scale32_with_i48
+func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_input_output_unsigned
+func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_output_unsigned_input
+func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_input_unsigned_output
+func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_input_unsigned_output
+func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_unsigned_input
+func.func @test_error_i48_unsigned_input(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_input
+func.func @test_error_i32_unsigned_input(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_output
+func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..fe4cc49e89c0d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1669,6 +1669,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
   return %0 : tensor<13x21x3xi16>
 }
 
+// -----
+// CHECK-LABEL: test_error_double_round_without_scale32
+func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
 // -----
 // CHECK-LABEL: test_matmul_a_zp_same_element_type
 func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index fb8726cba1853..a42cf03a0a5cb 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -319,3 +319,16 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
     : (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
   return %0 : tensor<1x4x8x19x34xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_error_scalar_input_with_per_channel
+func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
+  %multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
+  return %0 : tensor<i16>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/137021


More information about the Mlir-commits mailing list