[Mlir-commits] [mlir] 107ca63 - [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (#155967)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 13 19:10:23 PDT 2025
Author: ShivaChen
Date: 2025-10-14T10:10:18+08:00
New Revision: 107ca636352bcf31535924fa03df2852b70d0af3
URL: https://github.com/llvm/llvm-project/commit/107ca636352bcf31535924fa03df2852b70d0af3
DIFF: https://github.com/llvm/llvm-project/commit/107ca636352bcf31535924fa03df2852b70d0af3.diff
LOG: [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (#155967)
The shift, multiplier, inputZp, and outputZp can be either constant or
non-constant, depending on whether dynamic extension is enabled.
When these values are non-constant, they are added as inputs to
linalg::GenericOp, and corresponding affine maps are appended to the
indexingMaps.
This commit helps pass the following TOSA conformance tests:
rescale_22x20_i32_outi8_sc0_rmS_pc0_iu0_ou0_dyn
rescale_31x18_i8_outi8_sc0_rmS_pc0_iu1_ou0_dyn
rescale_20x19_i16_outi8_sc0_rmS_pc0_iu1_ou0_dyn
Added:
Modified:
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a5336ed6bf2cd..00df14b1bdb77 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1392,6 +1392,137 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};
+// Drops the only dimension of a single-element tensor, turning tensor<1xiN>
+// into the rank-0 tensor<iN>.
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  // The result is rank-0 and keeps the element type of the input.
+  Type elementTy = cast<RankedTensorType>(input.getType()).getElementType();
+  RankedTensorType scalarTensorTy = RankedTensorType::get({}, elementTy);
+  // A rank-0 result has no dimensions, so the reassociation list stays empty.
+  SmallVector<ReassociationExprs, 1> noGroups;
+  return rewriter.create<tensor::CollapseShapeOp>(loc, scalarTensorTy, input,
+                                                  noGroups);
+}
+
+// Returns a copy of 'input' with every element narrowed to int8_t using
+// static_cast.
+static llvm::SmallVector<int8_t>
+convertToI8(const llvm::SmallVector<int32_t> &input) {
+  llvm::SmallVector<int8_t> narrowed;
+  narrowed.reserve(input.size());
+  for (int32_t wide : input)
+    narrowed.push_back(static_cast<int8_t>(wide));
+  return narrowed;
+}
+
+// Prepares the shift or the multiplier of 'op' for use in the body of the
+// linalg::GenericOp that implements the rescale.
+//
+// The shift or multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If constant values are available and there is exactly one, a scalar
+//   arith.constant is created and returned through 'constant'.
+// - Otherwise the shift or multiplier is added as an input to
+//   linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+//
+// Parameters:
+//   values      - constant values already extracted by the caller; may be
+//                 empty even when 'isConstant' is true.
+//   isConstant  - whether the operand matched a constant.
+//   constant    - out: set only for the single-value constant case.
+//   arg         - out: index of the last entry of 'indexingMaps' on return;
+//                 only meaningful when 'constant' was not set.
+//   isShift     - selects op.getShift() with i8 elements instead of
+//                 op.getMultiplier() with i32 elements.
+static void setupLinalgGenericOpInputAndIndexingMap(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
+    bool isShift = false) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  // Per-channel values are indexed by the innermost dimension of the input.
+  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  // The operand can be constant while 'values' is empty (the values are only
+  // extracted when both shift and multiplier are constant). Building a
+  // constant from an empty list would yield a zero-element tensor input, so
+  // in that case fall back to feeding the operand itself.
+  const bool useConstantValues = isConstant && !values.empty();
+
+  if (useConstantValues) {
+    if (values.size() == 1) {
+      // A single value applies to every element: use a scalar constant.
+      IntegerAttr intAttr = isShift
+                                ? rewriter.getI8IntegerAttr(values.front())
+                                : rewriter.getI32IntegerAttr(values.front());
+      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
+    } else {
+      // If we are rescaling per-channel then we need to store the
+      // values in a buffer.
+      auto elementType =
+          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
+      auto tensorType = RankedTensorType::get(
+          {static_cast<int64_t>(values.size())}, elementType);
+      DenseIntElementsAttr eltAttr;
+      if (isShift)
+        eltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
+      else
+        eltAttr = DenseIntElementsAttr::get(tensorType, values);
+      genericInputs.push_back(
+          arith::ConstantOp::create(rewriter, loc, eltAttr));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, exprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    auto operand = isShift ? op.getShift() : op.getMultiplier();
+    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // If we are not rescaling per-channel then we need to collapse 1xN to N
+      // and push broadcastMap.
+      // broadcastMap = affine_map<(d0, d1) -> ()>
+      // It acts as a broadcast of the scalar value in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      // Per-channel operand: pass it through, indexed by the last dimension.
+      genericInputs.push_back(operand);
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, exprs,
+                                            rewriter.getContext()));
+    }
+  }
+  arg = indexingMaps.size() - 1;
+}
+
+// Produces the zero point widened for the arithmetic that follows.
+// The target width is 32 bits for the output zero point; for the input zero
+// point it is the larger of 32 and the bit width of 'valueTy'.
+//
+// The Zp value can be either constant or non-constant, depending on whether
+// dynamic extension is enabled. If 'maybeZp' fails, Zp is non-constant and is
+// read from 'blockArgs[zpArg]', i.e. it was passed as an input to
+// linalg::GenericOp.
+static Value getExtendZp(OpBuilder &builder, Type valueTy,
+                         FailureOr<int64_t> maybeZp, Location loc,
+                         ValueRange blockArgs, int64_t zpArg,
+                         bool isOutputZp = false) {
+  const int32_t valueWidth = valueTy.getIntOrFloatBitWidth();
+  const uint32_t targetWidth =
+      isOutputZp ? 32 : (valueWidth > 32 ? valueWidth : 32);
+  auto targetTy = builder.getIntegerType(targetWidth);
+
+  // Constant zero point: materialize it directly at the target width.
+  if (!failed(maybeZp))
+    return builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(targetTy, *maybeZp));
+
+  // Non-constant zero point: take it from the block arguments.
+  Value zp = blockArgs[zpArg];
+  Type zpTy = zp.getType();
+  // Already wide enough, nothing to extend.
+  if (zpTy.getIntOrFloatBitWidth() >= targetWidth)
+    return zp;
+
+  if (!zpTy.isUnsignedInteger())
+    return builder.create<arith::ExtSIOp>(loc, targetTy, zp);
+
+  // For ExtUIOp, the input must be signless.
+  // UnrealizedConversionCastOp will cast the input to signless type.
+  Value signlessZp =
+      UnrealizedConversionCastOp::create(
+          builder, loc, builder.getIntegerType(zpTy.getIntOrFloatBitWidth()),
+          zp)
+          .getResult(0);
+  return builder.create<arith::ExtUIOp>(loc, targetTy, signlessZp);
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1554,46 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
}
- // The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int32_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ // explicit cast is required here
+ shiftValues = llvm::to_vector(llvm::map_range(
+ shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
+ // Double round only occurs if shift is greater than 31, check that this
+ // is ever true.
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
-
- bool doubleRound =
- op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
@@ -1468,45 +1605,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMap(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant, shiftArg, true);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+ // - If the zp's are non-constant, add them as an inputs to
+ // linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ // - If the zp's are constant, they would be generated as arith.constant.
+ int64_t iZpArg = 0;
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ iZpArg = indexingMaps.size() - 1;
+ }
+ int64_t oZpArg = 0;
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
+ oZpArg = indexingMaps.size() - 1;
}
// Indexing maps for output values.
@@ -1526,36 +1661,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs, iZpArg);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs, oZpArg, true);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a7a73ae904042..780c25a9445a0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1538,6 +1538,92 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
// -----
+// Test: tosa.rescale with per_channel = false where multiplier, shift,
+// input_zp and output_zp are all function arguments (non-constant) of shape
+// tensor<1x...>. Each one is expected to be collapsed to a rank-0 tensor and
+// passed to linalg.generic with the `(d0) -> ()` map; the signless i8 zero
+// points are sign-extended to i32 inside the body.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+ // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+ // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c-128_i32 = arith.constant -128 : i32
+ // CHECK: %c127_i32 = arith.constant 127 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+// Test: tosa.rescale with per_channel = true and non-constant operands.
+// The tensor<2x...> multiplier and shift are expected to be passed to
+// linalg.generic unchanged with the `(d0) -> (d0)` map, while the
+// tensor<1xi8> zero points are collapsed to rank-0 tensors and use the
+// `(d0) -> ()` map.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const_per_channel
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]
+// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c-128_i32 = arith.constant -128 : i32
+ // CHECK: %c127_i32 = arith.constant 127 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+// Test: per-channel tosa.rescale with non-constant operands whose zero points
+// and result use the unsigned ui8 type (output_unsigned = true). Inside the
+// linalg.generic body each ui8 zero point is expected to be cast to signless
+// i8 via builtin.unrealized_conversion_cast and then zero-extended with
+// arith.extui; the result is clamped to [0, 255].
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const_per_channel_input_output_zp_ui8
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]
+// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const_per_channel_input_output_zp_ui8(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xui8>, %output_zp : tensor<1xui8>) -> (tensor<2xui8>) {
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xui8> into tensor<ui8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xui8> into tensor<ui8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xui8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<ui8>, tensor<ui8>) outs([[INIT]] : tensor<2xui8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: ui8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8):
+ // CHECK: [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG3]] : ui8 to i8
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extui [[OUTPUT_ZP_I8]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c0_i32 = arith.constant 0 : i32
+ // CHECK: %c255_i32 = arith.constant 255 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c0_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c255_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xui8>, tensor<1xui8>) -> tensor<2xui8>
+ return %0 : tensor<2xui8>
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @reverse
More information about the Mlir-commits
mailing list