[Mlir-commits] [mlir] [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (PR #155967)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Aug 31 22:21:15 PDT 2025
https://github.com/ShivaChen updated https://github.com/llvm/llvm-project/pull/155967
>From 0baeef537fc7a6545ff86fcc0f40722beca77c1a Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Sat, 30 Aug 2025 03:27:50 +0100
Subject: [PATCH 1/2] [mlir][tosa] Support RescaleOp with dynamic extension in
TosaToLinalg
The shift, multiplier, inputZp, and outputZp can be either constant or
non-constant, depending on whether dynamic extension is enabled.
When these values are non-constant, they are added as inputs to
linalg::GenericOp, and corresponding affine maps are appended to the
indexingMaps.
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 349 +++++++++++++-----
.../TosaToLinalg/tosa-to-linalg.mlir | 56 +++
2 files changed, 317 insertions(+), 88 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d0a431b1caa7f..2e3841ec85883 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1345,6 +1345,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};
+// Collapse a single-element rank-1 tensor into a rank-0 tensor with the same
+// element type, e.g.
+//   tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+// The rank-0 result can be read inside linalg::GenericOp through an indexing
+// map with no results, i.e. broadcast to every iteration point.
+// 'input' must be a ranked tensor; a rank-0 tensor is returned unchanged.
+static Value collapse1xNTensorToN(OpBuilder &builder, Value input,
+                                  Location loc) {
+  auto inputType = cast<RankedTensorType>(input.getType());
+  // Nothing to collapse: an empty-reassociation collapse_shape from rank 0
+  // to rank 0 would not be a valid reshape.
+  if (inputType.getRank() == 0)
+    return input;
+
+  auto collapsedType = RankedTensorType::get({}, inputType.getElementType());
+  // An empty reassociation folds every (unit) dimension away.
+  SmallVector<ReassociationExprs, 1> reassociation;
+  return tensor::CollapseShapeOp::create(builder, loc, collapsedType, input,
+                                         reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the multiplier values are known constants ('isConstant' is true and
+//   'multiplierValues' is non-empty):
+//   * a single value becomes a scalar i32 arith.constant returned through
+//     'multiplierConstant' (no generic input is added);
+//   * per-channel values become a constant i32 tensor that is appended to
+//     'genericInputs', indexed by the innermost dimension.
+// - Otherwise the multiplier operand itself is appended to 'genericInputs':
+//   * a single-element tensor (e.g. tensor<1xi32>) is collapsed to rank 0
+//     and broadcast with a map that has no results;
+//   * a per-channel tensor is indexed by the innermost dimension.
+// Whenever an input is appended, the matching map is appended to
+// 'indexingMaps' and 'multiplierArg' is set to that map's index, which the
+// region uses to select the corresponding block argument.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
+    int64_t &multiplierArg) {
+  Location loc = op.getLoc();
+  unsigned rank = cast<ShapedType>(op.getInput().getType()).getRank();
+
+  // Append 'input' with 'map' and remember which block argument it becomes.
+  auto appendInput = [&](Value input, AffineMap map) {
+    genericInputs.push_back(input);
+    indexingMaps.push_back(map);
+    multiplierArg = indexingMaps.size() - 1;
+  };
+  // Per-channel values are selected by the innermost dimension. Built lazily
+  // so that scalar cases never form d(rank - 1).
+  auto perChannelMap = [&]() {
+    SmallVector<AffineExpr, 2> exprs{rewriter.getAffineDimExpr(rank - 1)};
+    return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, exprs,
+                          rewriter.getContext());
+  };
+
+  // The caller only folds 'multiplierValues' when every scale operand is
+  // constant. A constant multiplier without folded values (e.g. because the
+  // shift is non-constant) must take the dynamic path; otherwise a zero-sized
+  // constant tensor would be indexed per channel.
+  if (isConstant && !multiplierValues.empty()) {
+    if (multiplierValues.size() == 1) {
+      // A single multiplier is used directly as a scalar constant.
+      multiplierConstant = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+      return;
+    }
+    // Per-channel multipliers are stored in a constant buffer.
+    auto multiplierType =
+        RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                              rewriter.getI32Type());
+    appendInput(arith::ConstantOp::create(
+                    rewriter, loc,
+                    DenseIntElementsAttr::get(multiplierType, multiplierValues)),
+                perChannelMap());
+    return;
+  }
+
+  Value multiplier = op.getMultiplier();
+  auto tensorType = dyn_cast<RankedTensorType>(multiplier.getType());
+  if (tensorType && tensorType.hasStaticShape() && tensorType.getRank() == 1 &&
+      tensorType.getDimSize(0) == 1) {
+    // Not per-channel: collapse tensor<1xT> to tensor<T> and broadcast it with
+    // a result-less map such as affine_map<(d0, d1) -> ()>.
+    appendInput(collapse1xNTensorToN(rewriter, multiplier, loc),
+                AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, {},
+                               rewriter.getContext()));
+    return;
+  }
+  // Per-channel runtime multipliers are indexed by the innermost dimension.
+  appendInput(multiplier, perChannelMap());
+}
+
+// The shift may be either constant or non-constant, depending on whether
+// dynamic extension is enabled.
+// - If the shift values are known constants ('isConstant' is true and
+//   'shiftValues' is non-empty):
+//   * a single value becomes a scalar i8 arith.constant returned through
+//     'shiftConstant' (no generic input is added);
+//   * per-channel values become a constant i8 tensor that is appended to
+//     'genericInputs', indexed by the innermost dimension.
+// - Otherwise the shift operand itself is appended to 'genericInputs':
+//   * a single-element tensor (e.g. tensor<1xi8>) is collapsed to rank 0 and
+//     broadcast with a map that has no results;
+//   * a per-channel tensor is indexed by the innermost dimension.
+// Whenever an input is appended, the matching map is appended to
+// 'indexingMaps' and 'shiftArg' is set to that map's index, which the region
+// uses to select the corresponding block argument.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
+    int64_t &shiftArg) {
+  Location loc = op.getLoc();
+  unsigned rank = cast<ShapedType>(op.getInput().getType()).getRank();
+
+  // Append 'input' with 'map' and remember which block argument it becomes.
+  auto appendInput = [&](Value input, AffineMap map) {
+    genericInputs.push_back(input);
+    indexingMaps.push_back(map);
+    shiftArg = indexingMaps.size() - 1;
+  };
+  // Per-channel values are selected by the innermost dimension. Built lazily
+  // so that scalar cases never form d(rank - 1).
+  auto perChannelMap = [&]() {
+    SmallVector<AffineExpr, 2> exprs{rewriter.getAffineDimExpr(rank - 1)};
+    return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, exprs,
+                          rewriter.getContext());
+  };
+
+  // The caller only folds 'shiftValues' when every scale operand is constant.
+  // A constant shift without folded values (e.g. because the multiplier is
+  // non-constant) must take the dynamic path; otherwise a zero-sized constant
+  // tensor would be indexed per channel.
+  if (isConstant && !shiftValues.empty()) {
+    if (shiftValues.size() == 1) {
+      // A single shift is used directly as a scalar constant.
+      shiftConstant = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+      return;
+    }
+    // Per-channel shifts are stored in a constant buffer.
+    auto shiftType =
+        RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                              rewriter.getIntegerType(8));
+    appendInput(arith::ConstantOp::create(
+                    rewriter, loc,
+                    DenseIntElementsAttr::get(shiftType, shiftValues)),
+                perChannelMap());
+    return;
+  }
+
+  Value shift = op.getShift();
+  auto tensorType = dyn_cast<RankedTensorType>(shift.getType());
+  if (tensorType && tensorType.hasStaticShape() && tensorType.getRank() == 1 &&
+      tensorType.getDimSize(0) == 1) {
+    // Not per-channel: collapse tensor<1xT> to tensor<T> and broadcast it with
+    // a result-less map such as affine_map<(d0, d1) -> ()>.
+    appendInput(collapse1xNTensorToN(rewriter, shift, loc),
+                AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, {},
+                               rewriter.getContext()));
+    return;
+  }
+  // Per-channel runtime shifts are indexed by the innermost dimension.
+  appendInput(shift, perChannelMap());
+}
+
+// Return the input zero point as a value to be used in subsequent arithmetic
+// operations inside the linalg::GenericOp region.
+// - Constant zero point ('maybeZp' succeeded): materialize an integer constant
+//   whose width is the width of 'valueTy', widened to at least 32 bits.
+// - Non-constant zero point ('maybeZp' failed): the zero point was passed as an
+//   input of linalg::GenericOp and is read from 'blockArgs[iZpArg]'. It is
+//   extended to i32 when narrower than 32 bits (zero-extended for unsigned
+//   integer element types, sign-extended otherwise) and returned as-is when
+//   it is 32 bits or wider.
+// NOTE(review): in the non-constant case signedness comes only from the
+// zero-point element type, so a signless zero point is always sign-extended,
+// even when the op's input_unsigned attribute is set. arith ops also expect
+// signless operands; confirm how unsigned element types reach this point.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs, int64_t iZpArg) {
+  if (succeeded(maybeZp)) {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Extend the zero point for sub-32-bit widths.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    return arith::ConstantOp::create(
+        builder, loc,
+        IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+
+  Value zp = blockArgs[iZpArg];
+  Type zpTy = zp.getType();
+  if (zpTy.getIntOrFloatBitWidth() >= 32)
+    return zp;
+  if (zpTy.isUnsignedInteger())
+    return arith::ExtUIOp::create(builder, loc, builder.getI32Type(), zp);
+  return arith::ExtSIOp::create(builder, loc, builder.getI32Type(), zp);
+}
+
+// Return the output zero point as an i32 value to be used in subsequent
+// arithmetic operations inside the linalg::GenericOp region.
+// - Constant zero point ('maybeZp' succeeded): materialize an i32 constant.
+// - Non-constant zero point ('maybeZp' failed): the zero point was passed as an
+//   input of linalg::GenericOp and is read from 'blockArgs[oZpArg]'. It is
+//   extended to i32 when narrower (zero-extended for unsigned integer element
+//   types, sign-extended otherwise), truncated when wider, and returned as-is
+//   when already 32 bits.
+// 'valueTy' is currently unused.
+// NOTE(review): in the non-constant case signedness comes only from the
+// zero-point element type, so a signless zero point is always sign-extended,
+// even when the op's output_unsigned attribute is set. Confirm this is
+// intended.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs, int64_t oZpArg) {
+  Type i32Ty = builder.getI32Type();
+  if (succeeded(maybeZp))
+    return arith::ConstantOp::create(builder, loc,
+                                     IntegerAttr::get(i32Ty, *maybeZp));
+
+  Value zp = blockArgs[oZpArg];
+  Type zpTy = zp.getType();
+  unsigned zpWidth = zpTy.getIntOrFloatBitWidth();
+  if (zpWidth < 32) {
+    if (zpTy.isUnsignedInteger())
+      return arith::ExtUIOp::create(builder, loc, i32Ty, zp);
+    return arith::ExtSIOp::create(builder, loc, i32Ty, zp);
+  }
+  if (zpWidth > 32)
+    return arith::TruncIOp::create(builder, loc, i32Ty, zp);
+  return zp;
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1376,40 +1569,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
}
- // The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int8_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+ // explicit cast is required here
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
+ // Double round only occurs if shift is greater than 31, check that this
+ // is ever true.
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
-
- bool doubleRound =
- op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
@@ -1421,45 +1617,41 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
+ setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMapForShift(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant, shiftArg);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+ // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ int64_t iZpArg = 0;
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ iZpArg = indexingMaps.size() - 1;
+ }
+ int64_t oZpArg = 0;
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
+ oZpArg = indexingMaps.size() - 1;
}
// Indexing maps for output values.
@@ -1479,39 +1671,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
+ auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs, iZpArg);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs, oZpArg);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
- Value multiplier = multiplierConstant ? multiplierConstant
- : blockArgs[multiplierArg];
+ Value multiplier =
+ multiplierConstant ? multiplierConstant : blockArgs[multiplierArg];
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
if (valueTy.isUnsignedInteger()) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3fc513f823a1a..b97d7bebec1e9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,62 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
// -----
+// Rescale whose multiplier, shift and zero points are all function arguments
+// (non-constant): each tensor<1xT> operand is expected to be collapsed to a
+// rank-0 tensor and fed to linalg.generic through a broadcast map (d0) -> ().
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+ // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+ // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c-128_i32 = arith.constant -128 : i32
+ // CHECK: %c127_i32 = arith.constant 127 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
+// Per-channel rescale with non-constant multiplier and shift: the per-channel
+// tensors are expected to be fed to linalg.generic directly through the
+// identity map (d0) -> (d0), while the tensor<1xT> zero points are collapsed
+// to rank 0 and broadcast through (d0) -> ().
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const_per_channel
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+// CHECK-SAME: [[ARG1:%[0-9a-zA-Z_]*]]
+// CHECK-SAME: [[ARG2:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c-128_i32 = arith.constant -128 : i32
+ // CHECK: %c127_i32 = arith.constant 127 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @reverse
>From 41fc88d845dbc07c8704c0f5102f721cc827bb25 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Mon, 1 Sep 2025 06:20:05 +0100
Subject: [PATCH 2/2] Fix clang-format
---
mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 2e3841ec85883..ff660109f3372 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1683,8 +1683,8 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
unsigned outBitWidth = outIntType.getWidth();
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- Value multiplier =
- multiplierConstant ? multiplierConstant : blockArgs[multiplierArg];
+ Value multiplier = multiplierConstant ? multiplierConstant
+ : blockArgs[multiplierArg];
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
if (valueTy.isUnsignedInteger()) {
More information about the Mlir-commits
mailing list