[Mlir-commits] [mlir] e046f20 - [mlir][tosa] Enhance error_if and verify checks for RESCALE Op (#137021)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 25 01:50:45 PDT 2025


Author: Peng Sun
Date: 2025-04-25T09:50:42+01:00
New Revision: e046f2050578b8cda394143fafbfb476767b836a

URL: https://github.com/llvm/llvm-project/commit/e046f2050578b8cda394143fafbfb476767b836a
DIFF: https://github.com/llvm/llvm-project/commit/e046f2050578b8cda394143fafbfb476767b836a.diff

LOG: [mlir][tosa] Enhance error_if and verify checks for RESCALE Op (#137021)

* add verifier for rank-0 input with per-channel
* add checkErrorIfRescale to the tosa validation pass to align with
  TOSA v1.0
  * add LIT tests

Signed-off-by: Peng Sun <peng.sun at arm.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Dialect/Tosa/error_if_check.mlir
    mlir/test/Dialect/Tosa/invalid.mlir
    mlir/test/Dialect/Tosa/verifier.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..b5504ca84fa42 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3206,6 +3206,12 @@ LogicalResult RescaleOp::verify() {
   // otherwise numChannel is dimension in input shape's last axis
   int64_t numChannels = 1;
   if (getPerChannel()) {
+    if (inputType.getRank() < 1) {
+      emitOpError("requires input to be at least rank 1 when per_channel is "
+                  "true, but got rank ")
+          << inputType.getRank();
+      return failure();
+    }
     numChannels = inputType.getDimSize(inputType.getRank() - 1);
   }
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index baa202833e285..06c2036923dfe 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1033,8 +1033,88 @@ bool checkErrorIfTable(Operation *op) {
   return true;
 }
 
+bool checkErrorIfRescale(Operation *op) { // ERROR_IF checks for tosa.rescale: emits an op error and returns false on the first violation, true otherwise.
+  auto rescale = dyn_cast<tosa::RescaleOp>(op);
+  if (!rescale)
+    return true; // Not a tosa.rescale op; nothing to check.
+
+  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
+  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
+  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
+      !outputType.getElementType().isInteger())
+    return true; // The checks below only apply to shaped, integer-element input and output.
+
+  auto inElemType = inputType.getElementType();
+  auto outElemType = outputType.getElementType();
+  auto inWidth = inElemType.getIntOrFloatBitWidth(); // Element types are compared by bit width only below.
+  auto outWidth = outElemType.getIntOrFloatBitWidth();
+
+  bool inputUnsigned = rescale.getInputUnsigned();
+  bool outputUnsigned = rescale.getOutputUnsigned();
+
+  bool scale32 = rescale.getScale32();
+  auto roundingMode = rescale.getRoundingMode();
+
+  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
+  if (scale32 && inWidth == 48) {
+    op->emitOpError() << "scale32 is not allowed with 48-bit input.";
+    return false;
+  }
+
+  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
+  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+    op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true.";
+    return false;
+  }
+
+  // ERROR_IF(input_unsigned && output_unsigned)
+  if (inputUnsigned && outputUnsigned) {
+    op->emitOpError() << "input and output cannot be both unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
+  if (outWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 output type is not allowed with unsigned input.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
+  if (inWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 input type is not allowed with unsigned output.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
+  if (inWidth == 48 && outputUnsigned) {
+    op->emitOpError() << "i48 input type is not allowed with unsigned output.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
+  if (inWidth == 48 && inputUnsigned) {
+    op->emitOpError() << "i48 input type cannot be unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
+  if (inWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 input type cannot be unsigned.";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
+  if (outWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 output type cannot be unsigned.";
+    return false;
+  }
+
+  return true; // No ERROR_IF condition matched.
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
+  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
+      !checkErrorIfTable(op) || !checkErrorIfRescale(op))
     return failure();
   return success();
 }

diff  --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 65a69be91e0c8..ac161128694cc 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -129,3 +129,99 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
     %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
     return %0 : tensor<2x64xi8>
 }
+
+// -----
+// CHECK-LABEL: test_error_scale32_with_i48
+func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_input_output_unsigned
+func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error at +1 {{'tosa.rescale' op input and output cannot be both unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_output_unsigned_input
+func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error at +1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_input_unsigned_output
+func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_input_unsigned_output
+func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_unsigned_input
+func.func @test_error_i48_unsigned_input(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_input
+func.func @test_error_i32_unsigned_input(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error at +1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_output
+func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error at +1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}

diff  --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index b147c94fde9b0..4a341d583426a 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1638,6 +1638,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
   return %0 : tensor<13x21x3xi16>
 }
 
+// -----
+// CHECK-LABEL: test_error_double_round_without_scale32
+func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error at +1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
 // -----
 // CHECK-LABEL: test_matmul_a_zp_same_element_type
 func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {

diff  --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..e88fc11d2be88 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,15 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
+// -----
+
+func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
+  %multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error at +1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
+  return %0 : tensor<i16>
+}


        


More information about the Mlir-commits mailing list