[Mlir-commits] [mlir] [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (PR #155967)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 28 20:24:24 PDT 2025
https://github.com/ShivaChen created https://github.com/llvm/llvm-project/pull/155967
The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled.
When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.
>From 6ec4bcc4849f720c9edbb85af3591b68962104d5 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Mon, 11 Aug 2025 04:09:29 +0100
Subject: [PATCH] [mlir][tosa] Support RescaleOp with dynamic extension in
TosaToLinalg
The shift, multiplier, inputZp, and outputZp can be either constant or
non-constant, depending on whether dynamic extension is enabled.
When these values are non-constant, they are added as inputs to
linalg::GenericOp, and corresponding affine maps are appended to the
indexingMaps.
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 345 +++++++++++++-----
.../TosaToLinalg/tosa-to-linalg.mlir | 28 ++
2 files changed, 277 insertions(+), 96 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 73046e0da361a..cc1289f397dff 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1342,6 +1342,186 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};
+// Collapses a single-element rank-1 tensor (tensor<1xT>) into a rank-0 tensor
+// (tensor<T>).
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+//
+// 'input' must have a RankedTensorType (enforced by cast<>). The result type
+// is always rank 0, so callers are expected to pass only tensors that hold
+// exactly one element.
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  // An empty reassociation list is what produces the rank-0 result.
+  SmallVector<ReassociationExprs, 1> reassociation;
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto collapsedType = RankedTensorType::get({}, inputType.getElementType());
+  // Use the static Op::create(builder, ...) form, matching the other op
+  // constructions in this file.
+  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
+                                         reassociation);
+}
+
+// Registers the multiplier of 'op' with the linalg::GenericOp being built.
+//
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled:
+// - A constant scalar multiplier (exactly one entry in 'multiplierValues') is
+//   materialized as an arith.constant and returned via 'multiplierConstant';
+//   nothing is appended to 'genericInputs' or 'indexingMaps'.
+// - Otherwise the multiplier becomes an input of linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+//   Per-channel values are indexed by the last loop dimension; a
+//   single-element tensor is collapsed to rank 0 and broadcast.
+//
+// If 'isConstant' is set but 'multiplierValues' is empty, the constant values
+// were not extracted by the caller. The operand is then handled like a
+// non-constant one rather than emitting an empty constant tensor.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+
+  // Appends affine_map<(d0, ..., dN) -> (dN)>, i.e. one multiplier per element
+  // of the last dimension. Built lazily so 'rank - 1' is only evaluated on the
+  // per-channel paths.
+  auto appendPerChannelMap = [&]() {
+    SmallVector<AffineExpr, 2> multiplierExprs{
+        rewriter.getAffineDimExpr(rank - 1)};
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                          /*symbolCount=*/0, multiplierExprs,
+                                          rewriter.getContext()));
+  };
+
+  if (isConstant && !multiplierValues.empty()) {
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+      return;
+    }
+    // If we are rescaling per-channel then we need to store the multiplier
+    // values in a buffer.
+    auto multiplierType =
+        RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                              rewriter.getI32Type());
+    genericInputs.push_back(arith::ConstantOp::create(
+        rewriter, loc,
+        DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+    appendPerChannelMap();
+    return;
+  }
+
+  // The multiplier operand is rank 1 both for the scalar case (tensor<1xiN>)
+  // and for the per-channel case (tensor<CxiN>), so the two must be told apart
+  // by the dimension size, not by the rank. A dynamic dimension size never
+  // compares equal to 1 and therefore takes the per-channel path.
+  auto multiplierTy = cast<ShapedType>(op.getMultiplier().getType());
+  if (multiplierTy.getRank() == 1 && multiplierTy.getDimSize(0) == 1) {
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // An empty result list makes linalg::GenericOp broadcast the scalar value
+    // to every iteration.
+    AffineMap broadcastMap =
+        AffineMap::get(rank, 0, {}, rewriter.getContext());
+    genericInputs.push_back(
+        collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+    indexingMaps.push_back(broadcastMap);
+    return;
+  }
+
+  genericInputs.push_back(op.getMultiplier());
+  appendPerChannelMap();
+}
+
+// Registers the shift of 'op' with the linalg::GenericOp being built.
+//
+// The shift may be either constant or non-constant, depending on
+// whether dynamic extension is enabled:
+// - A constant scalar shift (exactly one entry in 'shiftValues') is
+//   materialized as an arith.constant and returned via 'shiftConstant';
+//   nothing is appended to 'genericInputs' or 'indexingMaps'.
+// - Otherwise the shift becomes an input of linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+//   Per-channel values are indexed by the last loop dimension; a
+//   single-element tensor is collapsed to rank 0 and broadcast.
+//
+// If 'isConstant' is set but 'shiftValues' is empty, the constant values were
+// not extracted by the caller. The operand is then handled like a
+// non-constant one rather than emitting an empty constant tensor.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+
+  // Appends affine_map<(d0, ..., dN) -> (dN)>, i.e. one shift per element of
+  // the last dimension. Built lazily so 'rank - 1' is only evaluated on the
+  // per-channel paths.
+  auto appendPerChannelMap = [&]() {
+    SmallVector<AffineExpr, 2> shiftExprs = {
+        rewriter.getAffineDimExpr(rank - 1)};
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                          /*symbolCount=*/0, shiftExprs,
+                                          rewriter.getContext()));
+  };
+
+  if (isConstant && !shiftValues.empty()) {
+    if (shiftValues.size() == 1) {
+      shiftConstant = arith::ConstantOp::create(
+          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+      return;
+    }
+    // If we are rescaling per-channel then we need to store the shift
+    // values in a buffer.
+    auto shiftType =
+        RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                              rewriter.getIntegerType(8));
+    genericInputs.push_back(arith::ConstantOp::create(
+        rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+    appendPerChannelMap();
+    return;
+  }
+
+  // The shift operand is rank 1 both for the scalar case (tensor<1xi8>) and
+  // for the per-channel case (tensor<Cxi8>), so the two must be told apart by
+  // the dimension size, not by the rank. A dynamic dimension size never
+  // compares equal to 1 and therefore takes the per-channel path.
+  auto shiftTy = cast<ShapedType>(op.getShift().getType());
+  if (shiftTy.getRank() == 1 && shiftTy.getDimSize(0) == 1) {
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // An empty result list makes linalg::GenericOp broadcast the scalar value
+    // to every iteration.
+    AffineMap broadcastMap =
+        AffineMap::get(rank, 0, {}, rewriter.getContext());
+    genericInputs.push_back(collapse1xNTensorToN(rewriter, op.getShift(), loc));
+    indexingMaps.push_back(broadcastMap);
+    return;
+  }
+
+  genericInputs.push_back(op.getShift());
+  appendPerChannelMap();
+}
+
+// Returns the input zero point, widened for the subsequent arithmetic
+// operations.
+//
+// The Zp value can be either constant or non-constant, depending on whether
+// dynamic extension is enabled:
+// - 'maybeZp' succeeded: the zero point is constant and is emitted as an
+//   arith.constant whose width is max(32, bitwidth of 'valueTy').
+// - 'maybeZp' failed: the zero point is non-constant and is read from
+//   'blockArgs[zpArgIdx]', i.e. it must have been passed as an input to
+//   linalg::GenericOp. Values narrower than 32 bits are extended to i32
+//   (zero-extended for unsigned integer types, sign-extended otherwise);
+//   wider values are returned unchanged.
+//
+// 'zpArgIdx' defaults to 3, which is the right block argument only when the
+// multiplier and the shift are both linalg::GenericOp inputs as well. Callers
+// that materialize either of them as a constant must pass the actual index.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs, unsigned zpArgIdx = 3) {
+  if (succeeded(maybeZp)) {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Extend zeropoint for sub-32bits widths.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    return arith::ConstantOp::create(
+        builder, loc,
+        IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+
+  assert(zpArgIdx < blockArgs.size() &&
+         "input zero point block argument index out of range");
+  Value result = blockArgs[zpArgIdx];
+  Type zpTy = result.getType();
+  if (zpTy.getIntOrFloatBitWidth() >= 32)
+    return result;
+  if (zpTy.isUnsignedInteger())
+    return arith::ExtUIOp::create(builder, loc, builder.getI32Type(), result);
+  return arith::ExtSIOp::create(builder, loc, builder.getI32Type(), result);
+}
+
+// Returns the output zero point as an i32 value for the subsequent
+// arithmetic operations.
+//
+// The Zp value can be either constant or non-constant, depending on whether
+// dynamic extension is enabled:
+// - 'maybeZp' succeeded: the zero point is constant and is emitted as an i32
+//   arith.constant.
+// - 'maybeZp' failed: the zero point is non-constant and is read from
+//   'blockArgs[zpArgIdx]', i.e. it must have been passed as an input to
+//   linalg::GenericOp. Narrower values are extended to i32 (zero-extended for
+//   unsigned integer types, sign-extended otherwise) and wider values are
+//   truncated to i32.
+//
+// 'valueTy' is not consulted; it is kept so the signature mirrors
+// getExtendInputZp.
+//
+// 'zpArgIdx' defaults to 4, which is the right block argument only when the
+// multiplier, the shift and the input zero point are all linalg::GenericOp
+// inputs as well. Callers that materialize any of them as a constant must
+// pass the actual index.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs, unsigned zpArgIdx = 4) {
+  if (succeeded(maybeZp)) {
+    const int32_t attrBitwidth = 32;
+    return arith::ConstantOp::create(
+        builder, loc,
+        IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+
+  assert(zpArgIdx < blockArgs.size() &&
+         "output zero point block argument index out of range");
+  Value result = blockArgs[zpArgIdx];
+  Type zpTy = result.getType();
+  const unsigned zpBitwidth = zpTy.getIntOrFloatBitWidth();
+  if (zpBitwidth > 32)
+    return arith::TruncIOp::create(builder, loc, builder.getI32Type(), result);
+  if (zpBitwidth == 32)
+    return result;
+  if (zpTy.isUnsignedInteger())
+    return arith::ExtUIOp::create(builder, loc, builder.getI32Type(), result);
+  return arith::ExtSIOp::create(builder, loc, builder.getI32Type(), result);
+}
+
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1375,41 +1555,45 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// The shift and multiplier values.
DenseElementsAttr shiftElems;
- if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant shift input values");
+ bool isShiftConstant = false;
+ if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+ isShiftConstant = true;
DenseElementsAttr multiplierElems;
- if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
- return rewriter.notifyMatchFailure(
- op, "tosa.rescale requires constant multiplier input values");
-
- llvm::SmallVector<int8_t> shiftValues =
- llvm::to_vector(shiftElems.getValues<int8_t>());
- // explicit cast is required here
- llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
- llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
- [](IntegerAttr attr) -> int32_t {
- return static_cast<int32_t>(attr.getInt());
- }));
-
- // If we shift by more than the bitwidth, this just sets to 0.
- for (int i = 0, s = multiplierValues.size(); i < s; i++) {
- if (shiftValues[i] > 63) {
- shiftValues[i] = 0;
- multiplierValues[i] = 0;
+ bool isMultiplierConstant = false;
+ if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+ isMultiplierConstant = true;
+
+ llvm::SmallVector<int8_t> shiftValues;
+ llvm::SmallVector<int32_t> multiplierValues;
+ StringAttr roundingMode;
+ bool doubleRound;
+
+ if (isMultiplierConstant && isShiftConstant) {
+ shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+ // explicit cast is required here
+ multiplierValues = llvm::to_vector(
+ llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+ [](IntegerAttr attr) -> int32_t {
+ return static_cast<int32_t>(attr.getInt());
+ }));
+
+ // If we shift by more than the bitwidth, this just sets to 0.
+ for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+ if (shiftValues[i] > 63) {
+ shiftValues[i] = 0;
+ multiplierValues[i] = 0;
+ }
}
- }
-
- // Double round only occurs if shift is greater than 31, check that this
- // is ever true.
+ // Double round only occurs if shift is greater than 31, check that this
+ // is ever true.
+ doubleRound = op.getRoundingMode() == "DOUBLE_ROUND" &&
+ llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+ } else
+ doubleRound = op.getRoundingMode() == "DOUBLE_ROUND";
- bool doubleRound =
- op.getRoundingMode() == "DOUBLE_ROUND" &&
- llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
- StringAttr roundingMode = doubleRound
- ? rewriter.getStringAttr("DOUBLE_ROUND")
- : rewriter.getStringAttr("SINGLE_ROUND");
+ roundingMode = doubleRound ? rewriter.getStringAttr("DOUBLE_ROUND")
+ : rewriter.getStringAttr("SINGLE_ROUND");
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(rank)};
@@ -1418,46 +1602,35 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// If we are rescaling per-channel then we need to store the multiplier
// values in a buffer.
Value multiplierConstant;
- int64_t multiplierArg = 0;
- if (multiplierValues.size() == 1) {
- multiplierConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
- } else {
- SmallVector<AffineExpr, 2> multiplierExprs{
- rewriter.getAffineDimExpr(rank - 1)};
- auto multiplierType =
- RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
- rewriter.getI32Type());
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc,
- DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, multiplierExprs,
- rewriter.getContext()));
-
- multiplierArg = indexingMaps.size() - 1;
- }
-
+ setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+ rewriter, multiplierValues, genericInputs, indexingMaps,
+ isMultiplierConstant, op, multiplierConstant);
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
- int64_t shiftArg = 0;
- if (shiftValues.size() == 1) {
- shiftConstant = arith::ConstantOp::create(
- rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
- } else {
- SmallVector<AffineExpr, 2> shiftExprs = {
- rewriter.getAffineDimExpr(rank - 1)};
- auto shiftType =
- RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
- rewriter.getIntegerType(8));
- genericInputs.push_back(arith::ConstantOp::create(
- rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
- indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
- /*symbolCount=*/0, shiftExprs,
- rewriter.getContext()));
- shiftArg = indexingMaps.size() - 1;
+ setupLinalgGenericOpInputAndIndexingMapForShift(
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+ shiftConstant);
+
+ // broadcastMap = affine_map<(d0, d1) -> ()>
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
+ AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+ FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+ FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+ // The inputZp and outputZp may be either constant or non-constant,
+ // depending on whether dynamic extension is enabled.
+ // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+ // 1. Pushing it into 'genericInputs'.
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
+ if (failed(maybeIZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+ indexingMaps.push_back(broadcastMap);
+ }
+ if (failed(maybeOZp)) {
+ genericInputs.push_back(
+ collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+ indexingMaps.push_back(broadcastMap);
}
// Indexing maps for output values.
@@ -1477,40 +1650,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Type valueTy = value.getType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
- if (failed(maybeIZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "input zero point cannot be statically determined");
- return;
- }
-
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
- // Extend zeropoint for sub-32bits widths.
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
- auto inputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
- *maybeIZp));
-
+ auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+ nestedLoc, blockArgs);
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
- if (failed(maybeOZp)) {
- (void)rewriter.notifyMatchFailure(
- op, "output zero point cannot be statically determined");
- return;
- };
+ auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+ nestedLoc, blockArgs);
IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
- const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
- auto outputZp = arith::ConstantOp::create(
- nestedBuilder, loc,
- IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
- *maybeOZp));
-
- Value multiplier = multiplierConstant ? multiplierConstant
- : blockArgs[multiplierArg];
- Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
+
+ Value multiplier =
+ multiplierConstant ? multiplierConstant : blockArgs[1];
+ Value shift = shiftConstant ? shiftConstant : blockArgs[2];
if (valueTy.isUnsignedInteger()) {
value = UnrealizedConversionCastOp::create(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aee0caa91043d..8313173e1fec9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,34 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
// -----
+// Checks the lowering of tosa.rescale when multiplier, shift, input_zp and
+// output_zp are all function arguments rather than constants: each 1-element
+// operand is collapsed to a rank-0 tensor and passed to linalg.generic as an
+// extra input with a scalar (empty-result) indexing map.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+ // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+ // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+ // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+ // CHECK: ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+ // CHECK: [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+ // CHECK: [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+ // CHECK: [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+ // CHECK: [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+ // CHECK: [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = "DOUBLE_ROUND"} : (i32, i32, i8) -> i32
+ // CHECK: [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+ // CHECK: %c-128_i32 = arith.constant -128 : i32
+ // CHECK: %c127_i32 = arith.constant 127 : i32
+ // CHECK: [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+ // CHECK: [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "DOUBLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+ return %0 : tensor<2xi8>
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @reverse
More information about the Mlir-commits
mailing list