[Mlir-commits] [mlir] [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (PR #155967)

llvmlistbot at llvm.org
Sun Aug 31 22:21:15 PDT 2025


https://github.com/ShivaChen updated https://github.com/llvm/llvm-project/pull/155967

From 0baeef537fc7a6545ff86fcc0f40722beca77c1a Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Sat, 30 Aug 2025 03:27:50 +0100
Subject: [PATCH 1/2] [mlir][tosa] Support RescaleOp with dynamic extension in
 TosaToLinalg

The shift, multiplier, inputZp, and outputZp can be either constant or
non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to
linalg::GenericOp, and corresponding affine maps are appended to the
indexingMaps.
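
For illustration, a simplified sketch of the resulting lowering when all four
operands are non-constant scalars, mirroring the new test added below (SSA
names are illustrative):

  %m   = tensor.collapse_shape %multiplier [] : tensor<1xi32> into tensor<i32>
  %s   = tensor.collapse_shape %shift [] : tensor<1xi8> into tensor<i8>
  %izp = tensor.collapse_shape %input_zp [] : tensor<1xi8> into tensor<i8>
  %ozp = tensor.collapse_shape %output_zp [] : tensor<1xi8> into tensor<i8>
  %init = tensor.empty() : tensor<2xi8>
  %r = linalg.generic
         {indexing_maps = [#map0, #map1, #map1, #map1, #map1, #map0],
          iterator_types = ["parallel"]}
         ins(%in, %m, %s, %izp, %ozp
             : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>)
         outs(%init : tensor<2xi8>) { ... }

where #map0 = affine_map<(d0) -> (d0)> and #map1 = affine_map<(d0) -> ()>.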
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 349 +++++++++++++-----
 .../TosaToLinalg/tosa-to-linalg.mlir          |  56 +++
 2 files changed, 317 insertions(+), 88 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d0a431b1caa7f..2e3841ec85883 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1345,6 +1345,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
   }
 };
 
+// Collapse a tensor<1xiN> into a rank-0 tensor<iN>, e.g.
+// tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  SmallVector<ReassociationExprs, 1> reassociation;
+  // Create the collapsed type
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto elemType = inputType.getElementType();
+  auto collapsedType = RankedTensorType::get({}, elemType);
+  // Emit the collapse op
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+                                                  reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
+// by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the multiplier is constant, set 'multiplierConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
+    int64_t &multiplierArg) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> multiplierExprs{
+      rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If the multiplier is a scalar, materialize it as an arith.constant.
+    // If we are rescaling per-channel, store the multiplier values in a
+    // constant buffer that becomes an input to the linalg::GenericOp.
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+    } else {
+      auto multiplierType =
+          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                                rewriter.getI32Type());
+      genericInputs.push_back(arith::ConstantOp::create(
+          rewriter, loc,
+          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    // A non-constant scalar multiplier arrives as tensor<1xiN>: collapse it
+    // to a rank-0 tensor and use a broadcast map. Otherwise (per-channel)
+    // keep the per-element affine map.
+    auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier().getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // broadcastMap, e.g. affine_map<(d0, d1) -> ()>, broadcasts the scalar
+      // value across all iteration dimensions of the linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getMultiplier());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  }
+  multiplierArg = indexingMaps.size() - 1;
+}
+
+// The shift may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs'.
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift is constant, set 'shiftConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
+    int64_t &shiftArg) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // If the shift is a scalar, materialize it as an arith.constant.
+    // If we are rescaling per-channel, store the shift values in a
+    // constant buffer that becomes an input to the linalg::GenericOp.
+    if (shiftValues.size() == 1) {
+      shiftConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+    } else {
+      auto shiftType =
+          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                                rewriter.getIntegerType(8));
+      genericInputs.push_back(arith::ConstantOp::create(
+          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    // A non-constant scalar shift arrives as tensor<1xiN>: collapse it to a
+    // rank-0 tensor and use a broadcast map. Otherwise (per-channel) keep
+    // the per-element affine map.
+    auto tensorType = dyn_cast<RankedTensorType>(op.getShift().getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // broadcastMap, e.g. affine_map<(d0, d1) -> ()>, broadcasts the scalar
+      // value across all iteration dimensions of the linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getShift(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getShift());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  }
+  shiftArg = indexingMaps.size() - 1;
+}
+
+// Return the input zero point, extended to at least 32 bits, for use in
+// subsequent arithmetic operations.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs, int64_t iZpArg) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[iZpArg];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    }
+  } else {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Extend zeropoint for sub-32bits widths.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
+// Return the i32 outputZp to be used in subsequent arithmetic operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs, int64_t oZpArg) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[oZpArg];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+      result =
+          builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+    }
+  } else {
+    const int32_t attrBitwidth = 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1376,40 +1569,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
       }
     }
 
-    // The shift and multiplier values.
     DenseElementsAttr shiftElems;
-    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant shift input values");
+    bool isShiftConstant =
+        matchPattern(op.getShift(), m_Constant(&shiftElems));
 
     DenseElementsAttr multiplierElems;
-    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant multiplier input values");
-
-    llvm::SmallVector<int8_t> shiftValues =
-        llvm::to_vector(shiftElems.getValues<int8_t>());
-    // explicit cast is required here
-    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
-        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
-                        [](IntegerAttr attr) -> int32_t {
-                          return static_cast<int32_t>(attr.getInt());
-                        }));
-
-    // If we shift by more than the bitwidth, this just sets to 0.
-    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
-      if (shiftValues[i] > 63) {
-        shiftValues[i] = 0;
-        multiplierValues[i] = 0;
+    bool isMultiplierConstant =
+        matchPattern(op.getMultiplier(), m_Constant(&multiplierElems));
+
+    llvm::SmallVector<int8_t> shiftValues;
+    llvm::SmallVector<int32_t> multiplierValues;
+    bool doubleRound;
+
+    if (isMultiplierConstant && isShiftConstant) {
+      shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
+      // An explicit cast to int32_t is required here.
+      multiplierValues = llvm::to_vector(
+          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                          [](IntegerAttr attr) -> int32_t {
+                            return static_cast<int32_t>(attr.getInt());
+                          }));
+
+      // If we shift by more than the bitwidth, this just sets to 0.
+      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+        if (shiftValues[i] > 63) {
+          shiftValues[i] = 0;
+          multiplierValues[i] = 0;
+        }
       }
-    }
+      // Double rounding only applies when the shift is greater than 31;
+      // check whether that is ever the case.
+      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+                    llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    } else {
+      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
+    }
 
-    // Double round only occurs if shift is greater than 31, check that this
-    // is ever true.
-
-    bool doubleRound =
-        op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
-        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
     RoundingMode roundingMode =
         doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
 
@@ -1421,45 +1617,41 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // values in a buffer.
     Value multiplierConstant;
     int64_t multiplierArg = 0;
-    if (multiplierValues.size() == 1) {
-      multiplierConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> multiplierExprs{
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto multiplierType =
-          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
-                                rewriter.getI32Type());
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc,
-          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, multiplierExprs,
-                                            rewriter.getContext()));
-
-      multiplierArg = indexingMaps.size() - 1;
-    }
+    setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+        rewriter, multiplierValues, genericInputs, indexingMaps,
+        isMultiplierConstant, op, multiplierConstant, multiplierArg);
 
     // If we are rescaling per-channel then we need to store the shift
     // values in a buffer.
     Value shiftConstant;
     int64_t shiftArg = 0;
-    if (shiftValues.size() == 1) {
-      shiftConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> shiftExprs = {
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto shiftType =
-          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
-                                rewriter.getIntegerType(8));
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, shiftExprs,
-                                            rewriter.getContext()));
-      shiftArg = indexingMaps.size() - 1;
+    setupLinalgGenericOpInputAndIndexingMapForShift(
+        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+        shiftConstant, shiftArg);
+
+    // broadcastMap, e.g. affine_map<(d0, d1) -> ()>, broadcasts a scalar
+    // value across all iteration dimensions of the linalg::GenericOp.
+    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+    // The inputZp and outputZp may be either constant or non-constant,
+    // depending on whether dynamic extension is enabled.
+    // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
+    //     1. Pushing it into 'genericInputs'.
+    //     2. Appending a corresponding affine map to 'indexingMaps'.
+    int64_t iZpArg = 0;
+    if (failed(maybeIZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+      indexingMaps.push_back(broadcastMap);
+      iZpArg = indexingMaps.size() - 1;
+    }
+    int64_t oZpArg = 0;
+    if (failed(maybeOZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+      indexingMaps.push_back(broadcastMap);
+      oZpArg = indexingMaps.size() - 1;
     }
 
     // Indexing maps for output values.
@@ -1479,39 +1671,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Type valueTy = value.getType();
 
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
-          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
-          // Extend zeropoint for sub-32bits widths.
-          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
-          auto inputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
-                               *maybeIZp));
+          auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
+                                          nestedLoc, blockArgs, iZpArg);
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
+          auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
+                                         nestedLoc, blockArgs, oZpArg);
 
           IntegerType outIntType =
               cast<IntegerType>(blockArgs.back().getType());
           unsigned outBitWidth = outIntType.getWidth();
-          const int32_t outAttrBitwidth = 32;
           assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
-          auto outputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
-                               *maybeOZp));
 
-          Value multiplier = multiplierConstant ? multiplierConstant
-                                                : blockArgs[multiplierArg];
+          Value multiplier =
+              multiplierConstant ? multiplierConstant : blockArgs[multiplierArg];
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
           if (valueTy.isUnsignedInteger()) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3fc513f823a1a..b97d7bebec1e9 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1478,6 +1478,62 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
 
 // -----
 
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+  // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const_per_channel
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+// CHECK-SAME:  [[ARG1:%[0-9a-zA-Z_]*]]
+// CHECK-SAME:  [[ARG2:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @reverse

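For the per-channel variant (second new test above), only the zero points are
collapsed and broadcast; the multiplier and shift keep the per-element map. A
simplified sketch, with illustrative SSA names:

  %izp = tensor.collapse_shape %input_zp [] : tensor<1xi8> into tensor<i8>
  %ozp = tensor.collapse_shape %output_zp [] : tensor<1xi8> into tensor<i8>
  %init = tensor.empty() : tensor<2xi8>
  %r = linalg.generic
         {indexing_maps = [#map0, #map0, #map0, #map1, #map1, #map0],
          iterator_types = ["parallel"]}
         ins(%in, %multiplier, %shift, %izp, %ozp
             : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<i8>)
         outs(%init : tensor<2xi8>) { ... }

where #map0 = affine_map<(d0) -> (d0)> and #map1 = affine_map<(d0) -> ()>.
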
From 41fc88d845dbc07c8704c0f5102f721cc827bb25 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Mon, 1 Sep 2025 06:20:05 +0100
Subject: [PATCH 2/2] Fix clang-format

---
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 2e3841ec85883..ff660109f3372 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1683,8 +1683,8 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           unsigned outBitWidth = outIntType.getWidth();
           assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
 
-          Value multiplier =
-              multiplierConstant ? multiplierConstant : blockArgs[multiplierArg];
+          Value multiplier = multiplierConstant ? multiplierConstant
+                                                : blockArgs[multiplierArg];
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
           if (valueTy.isUnsignedInteger()) {
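
The new cases can be exercised through the test file's existing RUN header,
which is along these lines (exact flags as in the file):

  // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s | FileCheck %s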


