[Mlir-commits] [mlir] [MLIR][TOSA] Fix validation for unsigned integer types in RescaleOp (PR #137838)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 29 09:28:38 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: NohHyeon Kwon (swote-git)

<details>
<summary>Changes</summary>

This patch fixes a bug in the TOSA RescaleOp verifier's zero-point check, which incorrectly rejected unsigned integer types (ui8, ui16) even though the TOSA specification supports them.

The verifier now properly handles unsigned integer types when the corresponding input_unsigned or output_unsigned attribute is set to true.

Added tests for ui8<->i8 rescale operations (one in each direction: ui8->i8 and i8->ui8).

Fixes https://github.com/llvm/llvm-project/issues/135699

---
Full diff: https://github.com/llvm/llvm-project/pull/137838.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+18-12) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+22) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index b2e471f2bba93..980ef18b975f9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2111,24 +2111,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
                                      const int64_t &zp,
                                      const std::string &operand) {
   bool isInputZp = (operand == "Input");
-
   bool tensorUnsigned =
-      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
+    isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
   StringRef tensorName = isInputZp ? "input" : "output";
-
   Type zpElemType = getElementTypeOrSelf(zpVal);
 
   if (zp != 0) {
-    if (!zpElemType.isInteger(8) &&
-        !(zpElemType.isInteger(16) && tensorUnsigned)) {
-      return op.emitOpError()
-             << "expect " << tensorName << "_zp of 0, got " << zp;
+    bool validType = zpElemType.isInteger(8);
+
+    if (tensorUnsigned && zpElemType.isInteger(8)) {
+      validType = true;
     }
-    if (zpElemType.isInteger(16) && tensorUnsigned &&
-        zp != static_cast<int16_t>(32768)) {
-      return op.emitOpError() << "expect " << tensorName
-                              << "_zp of 0 or 32768 for unsigned int16 "
-                              << tensorName << ", got " << zp;
+
+    if (zpElemType.isInteger(16) && tensorUnsigned) {
+      validType = true;
+      if (zp != 32768) {
+        return op.emitOpError() << "expect " << tensorName
+                << "_zp of 0 or 32768 for unsigned int16 "
+                << tensorName << ", got " << zp;
+      }
+    }
+
+    if (!validType) {
+      return op.emitOpError() 
+              << "expect " << tensorName << "_zp of 0, got " << zp;
     }
   }
 
diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir
index 75126a11ac504..08d2bd30cf971 100644
--- a/mlir/test/Dialect/Tosa/availability.mlir
+++ b/mlir/test/Dialect/Tosa/availability.mlir
@@ -622,6 +622,28 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439
   return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
+// -----
+// CHECK-LABEL: test_rescale
+func.func @test_rescale_unsigned_i8(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
+  %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+  // CHECK: tosa.rescale profiles: [ [pro_int] ]
+  // CHECK: tosa.rescale extensions: [ [int16] ]
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+  return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+}
+
+// -----
+// CHECK-LABEL: test_rescale
+func.func @test_rescale_to_unsigned_i8(%arg0: tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>> {
+  %input_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+  // CHECK: tosa.rescale profiles: [ [pro_int] ]
+  // CHECK: tosa.rescale extensions: [ [int16] ]
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", scale32 = true, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>
+  return %0 : tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>
+}
+
 // -----
 // CHECK-LABEL: test_const
 func.func @test_const(%arg0 : index) -> tensor<4xi32> {

``````````

</details>


https://github.com/llvm/llvm-project/pull/137838


More information about the Mlir-commits mailing list