[Mlir-commits] [mlir] [MLIR][TOSA-Linalg] Fix rescale lowering for unsigned input zp (PR #138313)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 2 10:54:15 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Thomas Preud'homme (RoboTux)

<details>
<summary>Changes</summary>

Lowering of tosa.rescale to Linalg unconditionally sign-extends the input
zero-point value, even when input_unsigned is true. This commit refactors
zero-point handling to share the same logic between the input and output
zero points.


---
Full diff: https://github.com/llvm/llvm-project/pull/138313.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+18-26) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+35-3) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                           rhsOrResult);
 }
 
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
 static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
-                       OpBuilder &rewriter) {
-  auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+                        bool isSigned, Location loc, OpBuilder &rewriter) {
+
+  // Zero the signed-extended bits if isSigned is false.
+  zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
   return rewriter.create<arith::ConstantOp>(
-      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+      loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
 }
 
 static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Value value = blockArgs[0];
           Type valueTy = value.getType();
 
-          // For now we do all of our math in 64-bit. This is not optimal but
-          // should be correct for now, consider computing correct bit depth
-          // later.
-          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
           if (failed(maybeIZp)) {
             (void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           }
 
-          auto inputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+          const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+          auto inputZp = createConstOpFromSExtZp(
+              *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
               nestedBuilder);
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
             return;
           };
 
-          // pre-process OutputZP as it can be unsigned
-          auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
-          APInt OZp(outBitwidth, !op.getOutputUnsigned());
-          OZp = static_cast<int64_t>(*maybeOZp);
-          *maybeOZp = op.getOutputUnsigned()
-                          ? static_cast<int64_t>(OZp.getZExtValue())
-                          : OZp.getSExtValue();
-
-          auto outputZp = createConstOpFromZpVal<int32_t>(
-              op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+          IntegerType outIntType =
+              cast<IntegerType>(blockArgs.back().getType());
+          unsigned outBitWidth = outIntType.getWidth();
+          auto outputZp = createConstOpFromSExtZp(
+              *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+              !op.getOutputUnsigned(), loc, nestedBuilder);
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
               nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
 
           // Saturate to the output size.
-          IntegerType outIntType =
-              cast<IntegerType>(blockArgs.back().getType());
-          unsigned outBitWidth = outIntType.getWidth();
-
           int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
           int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C17:%.+]] = arith.constant 17
+  // CHECK: [[C128:%.+]] = arith.constant 128
   // CHECK: [[C22:%.+]] = arith.constant 22
   // CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
   // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
   // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
   // CHECK: linalg.yield [[TRUNC]]
   %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
   %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
-  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
   %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
   %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
 
   return
 }
 
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+  // CHECK: [[C19689:%.+]] = arith.constant 19689
+  // CHECK: [[C15:%.+]] = arith.constant 15
+  // CHECK: [[INIT:%.+]] = tensor.empty()
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C0:%.+]] = arith.constant 0
+  // CHECK: [[C234:%.+]] = arith.constant 234
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+  // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+  // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+  // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+  // CHECK: return
+  return
+}
+
 // -----
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>

``````````

</details>


https://github.com/llvm/llvm-project/pull/138313


More information about the Mlir-commits mailing list