[Mlir-commits] [mlir] [mlir][tosa] Change Rescale zero points to be inputs (PR #130340)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 7 12:22:18 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Tai Ly (Tai78641)
<details>
<summary>Changes</summary>
* Update RescaleOp to take its zero points as operands instead of attributes.
* Check the input_zp data type against the input, and the output_zp data type
  against the output.
---
Patch is 61.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/130340.diff
15 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+13-10)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+9-2)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+22-8)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+69-34)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+2)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+4-1)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+45-24)
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+7-5)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+3-1)
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+54-18)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+5-3)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+6-2)
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+15)
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+6-2)
- (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+10-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..08f28a7538c3d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -216,15 +216,15 @@ profileComplianceMap = {
{fp32T, fp16T}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
- {{i8T, i8T},
- {i8T, i16T},
- {i8T, i32T},
- {i16T, i8T},
- {i16T, i16T},
- {i16T, i32T},
- {i32T, i8T},
- {i32T, i16T},
- {i32T, i32T}}}}},
+ {{i8T, i8T, i8T, i8T},
+ {i8T, i8T, i16T, i16T},
+ {i8T, i8T, i32T, i32T},
+ {i16T, i16T, i8T, i8T},
+ {i16T, i16T, i16T, i16T},
+ {i16T, i16T, i32T, i32T},
+ {i32T, i32T, i8T, i8T},
+ {i32T, i32T, i16T, i16T},
+ {i32T, i32T, i32T, i32T}}}}},
{"tosa.const",
{{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
{{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
@@ -384,7 +384,10 @@ extensionComplianceMap = {
{fp16T, fp8e5m2T},
{fp32T, fp8e5m2T}}}}},
{"tosa.rescale",
- {{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
+ {{{Extension::int16},
+ {{i48T, i48T, i8T, i8T},
+ {i48T, i48T, i16T, i16T},
+ {i48T, i48T, i32T, i32T}}}}},
{"tosa.const",
{{{Extension::int4}, {{i4T}}},
{{Extension::int16}, {{i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 097f78cd487ea..cd593a2816355 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2337,8 +2337,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Tosa_Tensor:$input,
Tosa_1DInt16Or32Tensor:$multiplier,
Tosa_1DInt8Tensor:$shift,
- I32Attr:$input_zp,
- I32Attr:$output_zp,
+ Tosa_ScalarIntOrFloatTensor:$input_zp,
+ Tosa_ScalarIntOrFloatTensor:$output_zp,
BoolAttr:$scale32,
BoolAttr:$double_round,
BoolAttr:$per_channel,
@@ -2355,6 +2355,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Extension<[Tosa_EXT_INT16]>,
];
+ let extraClassDeclaration = [{
+ FailureOr<int64_t> getInputZeroPoint();
+ FailureOr<int64_t> getOutputZeroPoint();
+ LogicalResult verifyInputZeroPoint(int64_t zp);
+ LogicalResult verifyOutputZeroPoint(int64_t zp);
+ }];
+
let hasVerifier = 1;
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f7dd33c7e8b53..89af54132f820 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
template <typename T>
static arith::ConstantOp
-createConstFromIntAttribute(Operation *op, const std::string &attrName,
- Type requiredAttrType, OpBuilder &rewriter) {
- auto castedN = static_cast<T>(
- cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
+createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
+ OpBuilder &rewriter) {
+ auto castedN = static_cast<T>(zp);
return rewriter.create<arith::ConstantOp>(
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
@@ -1491,11 +1490,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// later.
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
- auto inputZp = createConstFromIntAttribute<int32_t>(
- op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ if (failed(maybeIZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "input zero point cannot be statically determined");
+ return;
+ }
+
+ auto inputZp = createConstOpFromZpVal<int32_t>(
+ op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
nestedBuilder);
- auto outputZp = createConstFromIntAttribute<int32_t>(
- op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);
+
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ if (failed(maybeOZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+ return;
+ };
+
+ auto outputZp = createConstOpFromZpVal<int32_t>(
+ op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4711122dc76e2..4d1f0be567d3c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -254,6 +254,27 @@ static Type getStorageElementTypeOrSelf(Type type) {
return elementType;
}
+static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
+ Value valZp, StringRef name) {
+ Type eType = getStorageElementTypeOrSelf(val.getType());
+ Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
+
+ bool bothInts =
+ mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
+ bool bothFloats =
+ mlir::isa<FloatType>(eType) && mlir::isa<FloatType>(eZpType);
+ bool sameBitWidth =
+ (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
+
+ if ((!bothInts && !bothFloats) || !sameBitWidth) {
+ return op->emitOpError()
+ << "expected " << name << " and " << name
+ << "_zp to both be integer or float of the same bitwidth, but got "
+ << eType << " vs. " << eZpType;
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
@@ -1696,6 +1717,38 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
return success();
}
+static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
+ const int64_t &zp,
+ const std::string &operand) {
+ bool isInputZp = (zpVal == op.getInputZp());
+ bool isOutputZp = (zpVal == op.getOutputZp());
+ if (!isInputZp && !isOutputZp) {
+ return op.emitOpError("internal error: zero-point operand is neither "
+ "inputZp nor outputZp");
+ }
+
+ bool tensorUnsigned =
+ isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
+ StringRef tensorName = isInputZp ? "input" : "output";
+
+ Type zpElemType = getElementTypeOrSelf(zpVal);
+
+ if (zp != 0) {
+ if (!zpElemType.isInteger(8) &&
+ !(zpElemType.isInteger(16) && tensorUnsigned)) {
+ return op.emitOpError()
+ << "expect " << tensorName << "_zp of 0, got " << zp;
+ }
+ if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
+ return op.emitOpError() << "expect " << tensorName
+ << "_zp of 0 or 32768 for unsigned int16 "
+ << tensorName << ", got " << zp;
+ }
+ }
+
+ return success();
+}
+
#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
@@ -1714,6 +1767,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(RescaleOp, Input)
+ZERO_POINT_HELPER(RescaleOp, Output)
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2698,41 +2753,21 @@ LogicalResult RescaleOp::verify() {
return failure();
}
- auto input_zp = getInputZpAttr().getInt();
- if (input_zp != 0) {
- // only int8/uint8 and uint16 input can have non-zero input_zp
- if (!inputElementType.isInteger(8) &&
- !(inputElementType.isInteger(16) && getInputUnsigned())) {
- emitOpError("expect input_zp of 0, got ") << input_zp;
- return failure();
- }
- // input_zp must be either 0 or 32768 for uint16 input
- if (inputElementType.isInteger(16) && getInputUnsigned() &&
- input_zp != 32768) {
- emitOpError(
- "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
- << input_zp;
- return failure();
- }
- }
+ if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
+ .failed())
+ return failure();
- auto output_zp = getOutputZpAttr().getInt();
- if (output_zp != 0) {
- // only int8/uint8 and uint16 output can have non-zero output_zp
- if (!outputElementType.isInteger(8) &&
- !(outputElementType.isInteger(16) && getOutputUnsigned())) {
- emitOpError("expect output_zp of 0, got ") << output_zp;
- return failure();
- }
- // output_zp must be either 0 or 32768 for uint16 output
- if (outputElementType.isInteger(16) && getOutputUnsigned() &&
- output_zp != 32768) {
- emitOpError(
- "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
- << output_zp;
- return failure();
- }
- }
+ if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
+ .failed())
+ return failure();
+
+ FailureOr<int64_t> maybeIZp = getInputZeroPoint();
+ if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
+ return failure();
+
+ FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+ if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+ return failure();
auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
if (!multiplierType) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..4fba2e5dfde2b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
template <>
void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
addValue(op.getInput());
+ addValue(op.getInputZp());
+ addValue(op.getOutputZp());
addValue(op.getOutput());
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 77687b83e5e3c..842b50e804cbe 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -35,8 +35,11 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = false, input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a3ed8c2805282..49d3e86e8fcd0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1149,9 +1149,11 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
// CHECK-DAG: [[BOUNDED:%.+]] = arith.minsi [[CMAX]], [[LOWER]]
// CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
// CHECK-DAG: linalg.yield [[TRUNC]]
- %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
- %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1182,7 +1184,9 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<2xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
// CHECK: return
return
@@ -1195,19 +1199,22 @@ func.func @rescale_i8_unsigned_output(%arg0 : tensor<2xi8>) -> () {
// CHECK-LABEL: @rescale_i8_dyn_batch
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
- %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
- %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
+ %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16>} : () -> tensor<1xi16>
+ %shift = "tosa.const"() {values = dense<15> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<17> : tensor<1xi8>} : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() {values = dense<22> : tensor<1xi8>} : () -> tensor<1xi8>
+
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
- %1 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<?x2xi8>
+ %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, double_round = false, per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<?x2xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x2xi8>
return
}
@@ -1219,15 +1226,19 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
// CHECK-LABEL: @rescale_dyn
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]:
func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
+ %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32>} : () -> tensor<1xi32>
+ %shift = "tosa.const"() {values = dense<38> : tensor<1xi8>} : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
+ %output_zp = "tosa.const"() {values = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+
// CHECK: %[[C1:.+]] = arith.constant 1
// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: %[[C2:.+]] = arith.constant 2
// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
- %multiplier = "tosa.const"() {values = dense<1376784203> : tensor<1xi32> } : () -> tensor<1xi32>
- %shift = "tosa.const"() {values = dense<38> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {double_round = true, input_zp = 0 : i32, output_zp = 0 : i32, per_channel = false, scale32 = true, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, double_round = true, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<1x?x?x32xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x?x?x32xi8>
+
return
}
@@ -1257,7 +1268,9 @@ func.func @rescale_i8_unsigned_input(%arg0 : tensor<2xi8>) -> () {
// CHECK: linalg.yield [[TRUNC]]
%multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
%shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
- %0 = tosa.rescale %arg0, %multiplier, %shift {input_zp = 17 : i32, output_zp = 22 : i32, scale32 = false, double_round = false, per_channel = false, input_unsigned = true, output_unsigned = false} : (tensor<2xi8>, tensor<1x...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/130340
More information about the Mlir-commits
mailing list