[Mlir-commits] [mlir] [TOSA] Rescale output_zp fix (PR #136116)

Dmitriy Smirnov llvmlistbot at llvm.org
Thu Apr 17 05:51:35 PDT 2025


https://github.com/d-smirnov updated https://github.com/llvm/llvm-project/pull/136116

>From 981f7ae61b8174085c585e8e2a83e1ec7fff7bc2 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Thu, 17 Apr 2025 09:39:40 +0100
Subject: [PATCH] [TOSA] Rescale output_zp fix

Patch corrects output_zp in case of unsigned output by removing
the sign-extended part.
Also corrects the check in verifyZeroPoint for the int16 type.

Change-Id: I1b36a7b267d2f38d972733101b777fbd5643c12d
---
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp     | 8 ++++++++
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp                  | 3 ++-
 mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 6 +++---
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e18fa849e9f30..9594bef3bc298 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1490,6 +1490,14 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           };
 
+          // pre-process OutputZP as it can be unsigned
+          auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
+          APInt OZp(outBitwidth, !op.getOutputUnsigned());
+          OZp = static_cast<int64_t>(*maybeOZp);
+          *maybeOZp = op.getOutputUnsigned()
+                          ? static_cast<int64_t>(OZp.getZExtValue())
+                          : OZp.getSExtValue();
+
           auto outputZp = createConstOpFromZpVal<int32_t>(
               op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 59946ca54b933..8018fd69f5648 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1851,7 +1851,8 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
       return op.emitOpError()
              << "expect " << tensorName << "_zp of 0, got " << zp;
     }
-    if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
+    if (zpElemType.isInteger(16) && tensorUnsigned &&
+        zp != static_cast<int16_t>(32768)) {
       return op.emitOpError() << "expect " << tensorName
                               << "_zp of 0 or 32768 for unsigned int16 "
                               << tensorName << ", got " << zp;
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9258442de5a45..4dbf926b80d2a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1161,11 +1161,11 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: [[C17:%.+]] = arith.constant 17
-  // CHECK: [[C22:%.+]] = arith.constant 22
+  // CHECK: [[C234:%.+]] = arith.constant 234
   // CHECK-DAG: [[IN32:%.+]] = arith.extsi [[IN]]
   // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
-  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
   // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
   // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
@@ -1175,7 +1175,7 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
   %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
-  %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
   %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   // CHECK: return



More information about the Mlir-commits mailing list