[Mlir-commits] [mlir] [mlir][tosa] Enhance error_if and verify checks for RESCALE Op (PR #137021)

Peng Sun llvmlistbot at llvm.org
Wed Apr 23 10:18:06 PDT 2025


https://github.com/psunn updated https://github.com/llvm/llvm-project/pull/137021

>From e0d037fb53e0b6d3851a90d23c3059b2562aa4de Mon Sep 17 00:00:00 2001
From: Peng Sun <peng.sun at arm.com>
Date: Tue, 22 Apr 2025 21:12:30 +0000
Subject: [PATCH] [mlir][tosa] Enhance error_if and verify checks for RESCALE
 Op

  * add a verifier check rejecting rank-0 input when per_channel is true
  * add checkErrorIfRescale to the TOSA validation pass to align with
TOSA v1.0
  * add LIT tests

Change-Id: Ia07e8c2ee66d8ee4113bea5ad9fa859b5986b009
Signed-off-by: Peng Sun <peng.sun at arm.com>
---
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |   6 +
 .../Tosa/Transforms/TosaValidation.cpp        |  82 ++++++++++++-
 mlir/test/Dialect/Tosa/error_if_check.mlir    | 108 ++++++++++++++++++
 mlir/test/Dialect/Tosa/invalid.mlir           |  12 ++
 mlir/test/Dialect/Tosa/verifier.mlir          |  13 +++
 5 files changed, 220 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 1ab4ce7d4558b..f1bed1241f971 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3186,6 +3186,12 @@ LogicalResult RescaleOp::verify() {
   // otherwise numChannel is dimension in input shape's last axis
   int64_t numChannels = 1;
   if (getPerChannel()) {
+    // numChannels is taken from the last input dimension, which a rank-0
+    // input does not have.
+    if (inputType.getRank() < 1)
+      return emitOpError("requires input to be at least rank 1 when "
+                         "per_channel is true, but got rank ")
+             << inputType.getRank();
     numChannels = inputType.getDimSize(inputType.getRank() - 1);
   }
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index baa202833e285..06c2036923dfe 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1033,8 +1033,88 @@ bool checkErrorIfTable(Operation *op) {
   return true;
 }
 
+/// Checks the TOSA v1.0 RESCALE ERROR_IF rules that depend on element widths
+/// and the scale32, rounding_mode, input_unsigned and output_unsigned
+/// attributes. Emits an op error and returns false on the first violation.
+bool checkErrorIfRescale(Operation *op) {
+  auto rescale = dyn_cast<tosa::RescaleOp>(op);
+  if (!rescale)
+    return true;
+
+  // The rules only apply when both input and output have integer elements.
+  auto inputType = llvm::dyn_cast<ShapedType>(rescale.getInput().getType());
+  auto outputType = llvm::dyn_cast<ShapedType>(rescale.getOutput().getType());
+  if (!inputType || !outputType || !inputType.getElementType().isInteger() ||
+      !outputType.getElementType().isInteger())
+    return true;
+
+  auto inWidth = inputType.getElementType().getIntOrFloatBitWidth();
+  auto outWidth = outputType.getElementType().getIntOrFloatBitWidth();
+  bool inputUnsigned = rescale.getInputUnsigned();
+  bool outputUnsigned = rescale.getOutputUnsigned();
+  bool scale32 = rescale.getScale32();
+  auto roundingMode = rescale.getRoundingMode();
+
+  // ERROR_IF(scale32 && is_same<in_t,i48_t>())
+  if (scale32 && inWidth == 48) {
+    op->emitOpError() << "scale32 is not allowed with 48-bit input";
+    return false;
+  }
+
+  // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND))
+  if (!scale32 && roundingMode == "DOUBLE_ROUND") {
+    op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true";
+    return false;
+  }
+
+  // ERROR_IF(input_unsigned && output_unsigned)
+  if (inputUnsigned && outputUnsigned) {
+    op->emitOpError() << "input and output cannot be both unsigned";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t,i32_t>() && input_unsigned)
+  if (outWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 output type is not allowed with unsigned input";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i32_t>() && output_unsigned)
+  if (inWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 input type is not allowed with unsigned output";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t,i48_t>() && output_unsigned)
+  if (inWidth == 48 && outputUnsigned) {
+    op->emitOpError() << "i48 input type is not allowed with unsigned output";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i48_t> && input_unsigned)
+  if (inWidth == 48 && inputUnsigned) {
+    op->emitOpError() << "i48 input type cannot be unsigned";
+    return false;
+  }
+
+  // ERROR_IF(is_same<in_t, i32_t> && input_unsigned)
+  if (inWidth == 32 && inputUnsigned) {
+    op->emitOpError() << "i32 input type cannot be unsigned";
+    return false;
+  }
+
+  // ERROR_IF(is_same<out_t, i32_t> && output_unsigned)
+  if (outWidth == 32 && outputUnsigned) {
+    op->emitOpError() << "i32 output type cannot be unsigned";
+    return false;
+  }
+
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
-  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op))
+  if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
+      !checkErrorIfTable(op) || !checkErrorIfRescale(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 65a69be91e0c8..c6a173c92ff9a 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -129,3 +129,111 @@ func.func @test_i8_table_size(%arg0: tensor<2x64xi8>, %arg1: tensor<513xi8>) ->
     %0 = tosa.table %arg0, %arg1 : (tensor<2x64xi8>, tensor<513xi8>) -> tensor<2x64xi8>
     return %0 : tensor<2x64xi8>
 }
+
+// -----
+// CHECK-LABEL: test_error_scale32_with_i48_to_i16
+func.func @test_error_scale32_with_i48_to_i16(%arg0: tensor<1xi48>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+// CHECK-LABEL: test_error_scale32_with_i48
+func.func @test_error_scale32_with_i48(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op scale32 is not allowed with 48-bit input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi48>, tensor<1xi32>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_input_output_unsigned
+func.func @test_error_input_output_unsigned(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op input and output cannot be both unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_output_unsigned_input
+func.func @test_error_i32_output_unsigned_input(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type is not allowed with unsigned input}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_input_unsigned_output
+func.func @test_error_i32_input_unsigned_output(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_input_unsigned_output
+func.func @test_error_i48_input_unsigned_output(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type is not allowed with unsigned output}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i48_unsigned_input
+func.func @test_error_i48_unsigned_input(%arg0: tensor<1xi48>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i48 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_input
+func.func @test_error_i32_unsigned_input(%arg0: tensor<1xi32>) -> tensor<1xi8> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op i32 input type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<1xi32>, tensor<1xi16>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1xi8>
+  return %0 : tensor<1xi8>
+}
+
+// -----
+// CHECK-LABEL: test_error_i32_unsigned_output
+func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+  // expected-error@+1 {{'tosa.rescale' op i32 output type cannot be unsigned}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 269ed58fdc81c..fe4cc49e89c0d 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1669,6 +1669,18 @@ func.func @test_rescale_invalid_non_perchannel_shift_shape(%arg0: tensor<13x21x3
   return %0 : tensor<13x21x3xi16>
 }
 
+// -----
+// CHECK-LABEL: test_error_double_round_without_scale32
+func.func @test_error_double_round_without_scale32(%arg0: tensor<1xi8>) -> tensor<1xi16> {
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op DOUBLE_ROUND is only allowed with scale32=true}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
 // -----
 // CHECK-LABEL: test_matmul_a_zp_same_element_type
 func.func @test_matmul_a_zp_same_element_type(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index fb8726cba1853..a42cf03a0a5cb 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -319,3 +319,16 @@ func.func @test_conv3d_wholly_divisible_output_width(%arg0: tensor<1x4x8x21x19xf
     : (tensor<1x4x8x21x19xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x19x34xf32>
   return %0 : tensor<1x4x8x19x34xf32>
 }
+
+// -----
+
+// CHECK-LABEL: test_error_scalar_input_with_per_channel
+func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor<i16> {
+  %multiplier = "tosa.const"() {values = dense<4> : tensor<1xi32> } : () -> tensor<1xi32>
+  %shift = "tosa.const"() {values = dense<2> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  // expected-error@+1 {{'tosa.rescale' op requires input to be at least rank 1 when per_channel is true, but got rank 0}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
+  return %0 : tensor<i16>
+}



More information about the Mlir-commits mailing list