[Mlir-commits] [mlir] [MLIR][TOSA-Linalg] Fix rescale lowering for unsigned input zp (PR #138313)
Thomas Preud'homme
llvmlistbot at llvm.org
Tue May 6 07:01:24 PDT 2025
https://github.com/RoboTux updated https://github.com/llvm/llvm-project/pull/138313
>From cd1e65dbb2575d57fe2f21af57125f29a793ccda Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Fri, 2 May 2025 15:50:25 +0100
Subject: [PATCH 1/3] [MLIR][TOSA-Linalg] Fix rescale lowering for unsigned
input zp
Lowering of tosa.rescale to Linalg unconditionally sign-extends the input
zero-point value, even when input_unsigned is true. This commit refactors
zero-point handling so that the input and output zero-points share the
same extension logic.
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 44 ++++++++-----------
.../TosaToLinalg/tosa-to-linalg.mlir | 38 ++++++++++++++--
2 files changed, 53 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 95364c26d1a7d..857f2721e1328 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,13 +82,16 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
rhsOrResult);
}
-template <typename T>
+// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
static arith::ConstantOp
-createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
- OpBuilder &rewriter) {
- auto castedN = static_cast<T>(zp);
+createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
+ bool isSigned, Location loc, OpBuilder &rewriter) {
+
+ // Zero the signed-extended bits if isSigned is false.
+ zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
+
return rewriter.create<arith::ConstantOp>(
- op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
+ loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
}
static Value createLinalgBodyCalculationForElementwiseOp(
@@ -1467,11 +1470,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Value value = blockArgs[0];
Type valueTy = value.getType();
- // For now we do all of our math in 64-bit. This is not optimal but
- // should be correct for now, consider computing correct bit depth
- // later.
- int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
-
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp)) {
(void)rewriter.notifyMatchFailure(
@@ -1479,8 +1477,10 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
}
- auto inputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
+ const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
+ const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
+ auto inputZp = createConstOpFromSExtZp(
+ *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
nestedBuilder);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
@@ -1490,16 +1490,12 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
return;
};
- // pre-process OutputZP as it can be unsigned
- auto outBitwidth = outputTy.getElementType().getIntOrFloatBitWidth();
- APInt OZp(outBitwidth, !op.getOutputUnsigned());
- OZp = static_cast<int64_t>(*maybeOZp);
- *maybeOZp = op.getOutputUnsigned()
- ? static_cast<int64_t>(OZp.getZExtValue())
- : OZp.getSExtValue();
-
- auto outputZp = createConstOpFromZpVal<int32_t>(
- op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
+ IntegerType outIntType =
+ cast<IntegerType>(blockArgs.back().getType());
+ unsigned outBitWidth = outIntType.getWidth();
+ auto outputZp = createConstOpFromSExtZp(
+ *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
+ !op.getOutputUnsigned(), loc, nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
@@ -1527,10 +1523,6 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
// Saturate to the output size.
- IntegerType outIntType =
- cast<IntegerType>(blockArgs.back().getType());
- unsigned outBitWidth = outIntType.getWidth();
-
int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 7083d19f4372a..185f1973ecdc6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1241,10 +1241,10 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
// CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
- // CHECK: [[C17:%.+]] = arith.constant 17
+ // CHECK: [[C128:%.+]] = arith.constant 128
// CHECK: [[C22:%.+]] = arith.constant 22
// CHECK-DAG: [[IN32:%.+]] = arith.extui [[IN]]
- // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C17]]
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN32]], [[C128]]
// CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C0]], [[C1]] {rounding_mode = "SINGLE_ROUND"}
// CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C22]]
// CHECK-DAG: [[CMIN:%.+]] = arith.constant -128
@@ -1255,13 +1255,45 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
return
}
+// -----
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_i48_unsigned_output
+// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
+func.func @rescale_i48_unsigned_output(%arg0 : tensor<2xi48>) -> () {
+ // CHECK: [[C19689:%.+]] = arith.constant 19689
+ // CHECK: [[C15:%.+]] = arith.constant 15
+ // CHECK: [[INIT:%.+]] = tensor.empty()
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<2xi48>) outs([[INIT]] : tensor<2xi8>)
+ // CHECK: ^bb0([[IN:%.+]]: i48, [[UNUSED:%.+]]: i8):
+ // CHECK: [[C0:%.+]] = arith.constant 0
+ // CHECK: [[C234:%.+]] = arith.constant 234
+ // CHECK-DAG: [[IN_ZEROED:%.+]] = arith.subi [[IN]], [[C0]]
+ // CHECK-DAG: [[SCALED:%.+]] = tosa.apply_scale [[IN_ZEROED]], [[C19689]], [[C15]] {rounding_mode = "SINGLE_ROUND"}
+ // CHECK-DAG: [[SCALED_ZEROED:%.+]] = arith.addi [[SCALED]], [[C234]]
+ // CHECK-DAG: [[CMIN:%.+]] = arith.constant 0
+ // CHECK-DAG: [[CMAX:%.+]] = arith.constant 255
+ // CHECK-DAG: [[LOWER:%.+]] = arith.maxsi [[CMIN]], [[SCALED_ZEROED]]
+ // CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
+ // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
+ // CHECK: linalg.yield [[TRUNC]]
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi48>} : () -> tensor<1xi48>
+ %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi8>) -> tensor<2xi8>
+
+ // CHECK: return
+ return
+}
+
// -----
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
>From 1c363953e6ff0c2f7ce28c7c0a6cece4245b3682 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 6 May 2025 14:39:51 +0100
Subject: [PATCH 2/3] Deal with zeropoint sign in getter
Also clarify zeropoint extension rules.
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 29 +++++--------
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 43 ++++++++++---------
mlir/test/Dialect/Tosa/invalid.mlir | 2 +-
3 files changed, 34 insertions(+), 40 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 857f2721e1328..b9afbe21b1c6b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -82,18 +82,6 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
rhsOrResult);
}
-// Create i32 Value from zp, a zero-point value sign-extended from bitwidth.
-static arith::ConstantOp
-createConstOpFromSExtZp(int64_t zp, unsigned zpBitwidth, unsigned attrBitwidth,
- bool isSigned, Location loc, OpBuilder &rewriter) {
-
- // Zero the signed-extended bits if isSigned is false.
- zp = isSigned ? zp : zp & ((1 << zpBitwidth) - 1);
-
- return rewriter.create<arith::ConstantOp>(
- loc, IntegerAttr::get(rewriter.getIntegerType(attrBitwidth), zp));
-}
-
static Value createLinalgBodyCalculationForElementwiseOp(
Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
ConversionPatternRewriter &rewriter) {
@@ -1478,10 +1466,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- const int32_t attrBitwidth = inBitwidth > 32 ? 48 : 32;
- auto inputZp = createConstOpFromSExtZp(
- *maybeIZp, inBitwidth, attrBitwidth, !op.getInputUnsigned(), loc,
- nestedBuilder);
+ // Extend zeropoint for sub-32bits widths.
+ const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
+ auto inputZp = nestedBuilder.create<arith::ConstantOp>(loc,
+ IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
+ *maybeIZp));
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeOZp)) {
@@ -1493,9 +1482,11 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- auto outputZp = createConstOpFromSExtZp(
- *maybeOZp, outBitWidth, /*attrBitwidth=*/32,
- !op.getOutputUnsigned(), loc, nestedBuilder);
+ const int32_t outAttrBitwidth = 32;
+ assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
+ auto outputZp = nestedBuilder.create<arith::ConstantOp>(loc,
+ IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
+ *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index de06b621cbe3d..2f9c6d7870782 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2118,7 +2118,7 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
// return failure if val is not a constant
// set zp to -1 if val is non-zero float or val is not integer nor float
// otherwise set zp to val's constant value
-static FailureOr<int64_t> getZeroPoint(Value val) {
+static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
ElementsAttr zpAttr;
if (!matchPattern(val, m_Constant(&zpAttr))) {
return failure();
@@ -2135,7 +2135,10 @@ static FailureOr<int64_t> getZeroPoint(Value val) {
}
if (llvm::isa<IntegerType>(zpElemType)) {
- return zpAttr.getValues<APInt>()[0].getSExtValue();
+ if (signExtend)
+ return zpAttr.getValues<APInt>()[0].getSExtValue();
+ else
+ return zpAttr.getValues<APInt>()[0].getZExtValue();
}
// return non-zero value to trigger error check
@@ -2186,30 +2189,30 @@ static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
return success();
}
-#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
+#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
- return getZeroPoint(get##OPERAND_NAME##Zp()); \
+ return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \
} \
LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \
return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
}
-ZERO_POINT_HELPER(Conv2DOp, Input)
-ZERO_POINT_HELPER(Conv2DOp, Weight)
-ZERO_POINT_HELPER(Conv3DOp, Input)
-ZERO_POINT_HELPER(Conv3DOp, Weight)
-ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
-ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
-ZERO_POINT_HELPER(TransposeConv2DOp, Input)
-ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
-ZERO_POINT_HELPER(AvgPool2dOp, Input)
-ZERO_POINT_HELPER(AvgPool2dOp, Output)
-ZERO_POINT_HELPER(MatMulOp, A)
-ZERO_POINT_HELPER(MatMulOp, B)
-ZERO_POINT_HELPER(NegateOp, Input1)
-ZERO_POINT_HELPER(NegateOp, Output)
-ZERO_POINT_HELPER(RescaleOp, Input)
-ZERO_POINT_HELPER(RescaleOp, Output)
+ZERO_POINT_HELPER(Conv2DOp, Input, true)
+ZERO_POINT_HELPER(Conv2DOp, Weight, true)
+ZERO_POINT_HELPER(Conv3DOp, Input, true)
+ZERO_POINT_HELPER(Conv3DOp, Weight, true)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
+ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
+ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
+ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
+ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
+ZERO_POINT_HELPER(MatMulOp, A, true)
+ZERO_POINT_HELPER(MatMulOp, B, true)
+ZERO_POINT_HELPER(NegateOp, Input1, true)
+ZERO_POINT_HELPER(NegateOp, Output, true)
+ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
+ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 9ccb310c4491d..56d76585be71b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1517,7 +1517,7 @@ func.func @test_rescale_invalid_output_zp_u16(%arg0: tensor<13x21x3xi16>) -> ten
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi16>} : () -> tensor<1xi16>
- // expected-error at +1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got -1}}
+ // expected-error at +1 {{'tosa.rescale' op expect output_zp of 0 or 32768 for unsigned int16 output, got 65535}}
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = true} : (tensor<13x21x3xi16>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi16>) -> tensor<13x21x3xi16>
return %0 : tensor<13x21x3xi16>
}
>From f7e1988bf5aeaa48a174b755f58e3044580f146f Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Tue, 6 May 2025 15:00:08 +0100
Subject: [PATCH 3/3] Fix formatting
---
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b9afbe21b1c6b..0b69cd2814fb9 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1468,9 +1468,9 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
// Extend zeropoint for sub-32bits widths.
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = nestedBuilder.create<arith::ConstantOp>(loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = nestedBuilder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
+ *maybeIZp));
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeOZp)) {
@@ -1484,9 +1484,9 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned outBitWidth = outIntType.getWidth();
const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = nestedBuilder.create<arith::ConstantOp>(loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
+ auto outputZp = nestedBuilder.create<arith::ConstantOp>(
+ loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
+ *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
More information about the Mlir-commits
mailing list