[Mlir-commits] [mlir] 107ca63 - [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (#155967)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Oct 13 19:10:23 PDT 2025


Author: ShivaChen
Date: 2025-10-14T10:10:18+08:00
New Revision: 107ca636352bcf31535924fa03df2852b70d0af3

URL: https://github.com/llvm/llvm-project/commit/107ca636352bcf31535924fa03df2852b70d0af3
DIFF: https://github.com/llvm/llvm-project/commit/107ca636352bcf31535924fa03df2852b70d0af3.diff

LOG: [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (#155967)

The shift, multiplier, inputZp, and outputZp can be either constant or
non-constant, depending on whether dynamic extension is enabled.

When these values are non-constant, they are added as inputs to
linalg::GenericOp, and corresponding affine maps are appended to the
indexingMaps.

This commit helps pass the following TOSA conformance tests:
rescale_22x20_i32_outi8_sc0_rmS_pc0_iu0_ou0_dyn
rescale_31x18_i8_outi8_sc0_rmS_pc0_iu1_ou0_dyn
rescale_20x19_i16_outi8_sc0_rmS_pc0_iu1_ou0_dyn

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index a5336ed6bf2cd..00df14b1bdb77 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1392,6 +1392,137 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
   }
 };
 
+// Collapse tensor<1xiN> into tensor<iN>, e.g.
+//   tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+// The 0-d result lets the single element be broadcast to every iteration of a
+// linalg::GenericOp through an indexing map with no results.
+// 'input' must be a ranked tensor (enforced by the cast below).
+// NOTE(review): it is assumed to hold exactly one element; this is not
+// re-checked here.
+static Value collapse1xNTensorToN(OpBuilder &builder, Value input,
+                                  Location loc) {
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto collapsedType = RankedTensorType::get({}, inputType.getElementType());
+  // A 0-d result has no dimensions to group, hence no reassociation entries.
+  SmallVector<ReassociationExprs, 1> reassociation;
+  return tensor::CollapseShapeOp::create(builder, loc, collapsedType, input,
+                                         reassociation);
+}
+
+// Narrow each int32 value to int8 (used for shift amounts, which are
+// materialized as i8 constants). Values are expected to fit in int8; anything
+// wider is truncated by the static_cast.
+static llvm::SmallVector<int8_t> convertToI8(ArrayRef<int32_t> input) {
+  return llvm::to_vector(llvm::map_range(
+      input, [](int32_t val) { return static_cast<int8_t>(val); }));
+}
+
+// Registers the multiplier (or, when 'isShift' is set, the shift) of 'op' with
+// the linalg::GenericOp being built.
+//
+// The shift or multiplier may be either constant or non-constant, depending on
+// whether dynamic extension is enabled.
+// - If it is constant and its elements were extracted into 'values':
+//     * a single value is materialized as an arith.constant in 'constant';
+//     * several (per-channel) values are materialized as a constant tensor,
+//       pushed into 'genericInputs' and indexed by the innermost dimension.
+// - Otherwise the operand itself is added as an input to linalg::GenericOp by:
+//     1. Pushing it into 'genericInputs' (collapsed to a 0-d tensor when it is
+//        a tensor<1xT>, so that it is broadcast).
+//     2. Appending a corresponding affine map to 'indexingMaps'.
+// Whenever an input is pushed, 'arg' is set to its index; it is left untouched
+// when only 'constant' is set.
+//
+// Note: 'isConstant' may be true while 'values' is empty, because the caller
+// extracts the values only when the shift AND the multiplier are both
+// constant. That case takes the operand path instead of materializing an
+// empty (tensor<0xT>) buffer that cannot be indexed per channel.
+static void setupLinalgGenericOpInputAndIndexingMap(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
+    bool isShift = false) {
+  Location loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+
+  // affine_map<(d0, ..., dN) -> (dN)>: one entry per innermost (channel)
+  // index. Built lazily so no out-of-range dimension is formed for 0-d inputs.
+  auto makeChannelMap = [&]() {
+    SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
+    return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, exprs,
+                          rewriter.getContext());
+  };
+
+  if (isConstant && !values.empty()) {
+    if (values.size() == 1) {
+      // One value for all elements: a scalar constant suffices.
+      IntegerAttr intAttr = isShift
+                                ? rewriter.getI8IntegerAttr(values.front())
+                                : rewriter.getI32IntegerAttr(values.front());
+      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
+      return;
+    }
+    // Per-channel rescale: store the values in a buffer.
+    Type elementType =
+        isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
+    auto tensorType = RankedTensorType::get(
+        {static_cast<int64_t>(values.size())}, elementType);
+    DenseIntElementsAttr eltAttr;
+    if (isShift)
+      eltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
+    else
+      eltAttr = DenseIntElementsAttr::get(tensorType, values);
+    genericInputs.push_back(arith::ConstantOp::create(rewriter, loc, eltAttr));
+    indexingMaps.push_back(makeChannelMap());
+    arg = indexingMaps.size() - 1;
+    return;
+  }
+
+  // Non-constant (or constant but not extracted) operand: feed it to the
+  // linalg::GenericOp directly.
+  Value operand = isShift ? op.getShift() : op.getMultiplier();
+  auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+  if (tensorType && tensorType.getRank() == 1 &&
+      tensorType.getDimSize(0) == 1) {
+    // A single shared value: collapse tensor<1xT> into tensor<T> and broadcast
+    // it with affine_map<(d0, ..., dN) -> ()>.
+    genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
+    indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                          /*symbolCount=*/0, {},
+                                          rewriter.getContext()));
+  } else {
+    // Per-channel values: index the operand by the innermost dimension.
+    genericInputs.push_back(operand);
+    indexingMaps.push_back(makeChannelMap());
+  }
+  arg = indexingMaps.size() - 1;
+}
+
+// Return the extended Zp to be used in subsequent arithmetic operations.
+//
+// The extended width is 32 bits for the output zero point ('isOutputZp') and
+// max(32, bitwidth of 'valueTy') for the input zero point.
+// The Zp value can be either constant or non-constant, depending on whether
+// dynamic extension is enabled:
+// - If 'maybeZp' holds a value, it is materialized as an arith.constant of the
+//   extended type.
+// - If 'maybeZp' fails, the Zp was passed as an input to linalg::GenericOp and
+//   is read from 'blockArgs[zpArg]'. An unsigned Zp is first reinterpreted as
+//   signless, because arith ops (e.g. arith.extui) require signless integers.
+//   A narrower Zp is then zero-extended (if it was unsigned) or sign-extended
+//   (otherwise). A Zp at least as wide as the extended type is not extended.
+static Value getExtendZp(OpBuilder &builder, Type valueTy,
+                         FailureOr<int64_t> maybeZp, Location loc,
+                         ValueRange blockArgs, int64_t zpArg,
+                         bool isOutputZp = false) {
+  const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+  const uint32_t attrBitwidth =
+      isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
+  IntegerType extendType = builder.getIntegerType(attrBitwidth);
+
+  // Constant Zp: emit it directly at the extended width.
+  if (succeeded(maybeZp))
+    return arith::ConstantOp::create(builder, loc,
+                                     IntegerAttr::get(extendType, *maybeZp));
+
+  // Non-constant Zp: it is a block argument of the linalg::GenericOp.
+  Value zp = blockArgs[zpArg];
+  Type zpTy = zp.getType();
+  const uint32_t zpBitwidth = zpTy.getIntOrFloatBitWidth();
+  const bool isUnsigned = zpTy.isUnsignedInteger();
+  // UnrealizedConversionCastOp reinterprets the unsigned type as signless.
+  if (isUnsigned)
+    zp = UnrealizedConversionCastOp::create(
+             builder, loc, builder.getIntegerType(zpBitwidth), zp)
+             .getResult(0);
+  if (zpBitwidth >= attrBitwidth)
+    return zp;
+  if (isUnsigned)
+    return arith::ExtUIOp::create(builder, loc, extendType, zp);
+  return arith::ExtSIOp::create(builder, loc, extendType, zp);
+}
+
 class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
 public:
   using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1554,46 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
       }
     }
 
-    // The shift and multiplier values.
     DenseElementsAttr shiftElems;
-    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant shift input values");
+    bool isShiftConstant = false;
+    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
+      isShiftConstant = true;
 
     DenseElementsAttr multiplierElems;
-    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
-      return rewriter.notifyMatchFailure(
-          op, "tosa.rescale requires constant multiplier input values");
-
-    llvm::SmallVector<int8_t> shiftValues =
-        llvm::to_vector(shiftElems.getValues<int8_t>());
-    // explicit cast is required here
-    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
-        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
-                        [](IntegerAttr attr) -> int32_t {
-                          return static_cast<int32_t>(attr.getInt());
-                        }));
-
-    // If we shift by more than the bitwidth, this just sets to 0.
-    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
-      if (shiftValues[i] > 63) {
-        shiftValues[i] = 0;
-        multiplierValues[i] = 0;
+    bool isMultiplierConstant = false;
+    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
+      isMultiplierConstant = true;
+
+    llvm::SmallVector<int32_t> shiftValues;
+    llvm::SmallVector<int32_t> multiplierValues;
+    bool doubleRound;
+
+    if (isMultiplierConstant && isShiftConstant) {
+      // explicit cast is required here
+      shiftValues = llvm::to_vector(llvm::map_range(
+          shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
+            return static_cast<int32_t>(attr.getInt());
+          }));
+      multiplierValues = llvm::to_vector(
+          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
+                          [](IntegerAttr attr) -> int32_t {
+                            return static_cast<int32_t>(attr.getInt());
+                          }));
+
+      // If we shift by more than the bitwidth, this just sets to 0.
+      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+        if (shiftValues[i] > 63) {
+          shiftValues[i] = 0;
+          multiplierValues[i] = 0;
+        }
       }
-    }
+      // Double round only occurs if shift is greater than 31, check that this
+      // is ever true.
+      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
+                    llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
+    } else
+      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
 
-    // Double round only occurs if shift is greater than 31, check that this
-    // is ever true.
-
-    bool doubleRound =
-        op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
-        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
     RoundingMode roundingMode =
         doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
 
@@ -1468,45 +1605,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
     // values in a buffer.
     Value multiplierConstant;
     int64_t multiplierArg = 0;
-    if (multiplierValues.size() == 1) {
-      multiplierConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> multiplierExprs{
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto multiplierType =
-          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
-                                rewriter.getI32Type());
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc,
-          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
-
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, multiplierExprs,
-                                            rewriter.getContext()));
-
-      multiplierArg = indexingMaps.size() - 1;
-    }
+    setupLinalgGenericOpInputAndIndexingMap(
+        rewriter, multiplierValues, genericInputs, indexingMaps,
+        isMultiplierConstant, op, multiplierConstant, multiplierArg);
 
     // If we are rescaling per-channel then we need to store the shift
     // values in a buffer.
     Value shiftConstant;
     int64_t shiftArg = 0;
-    if (shiftValues.size() == 1) {
-      shiftConstant = arith::ConstantOp::create(
-          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
-    } else {
-      SmallVector<AffineExpr, 2> shiftExprs = {
-          rewriter.getAffineDimExpr(rank - 1)};
-      auto shiftType =
-          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
-                                rewriter.getIntegerType(8));
-      genericInputs.push_back(arith::ConstantOp::create(
-          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
-      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
-                                            /*symbolCount=*/0, shiftExprs,
-                                            rewriter.getContext()));
-      shiftArg = indexingMaps.size() - 1;
+    setupLinalgGenericOpInputAndIndexingMap(
+        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
+        shiftConstant, shiftArg, true);
+
+    // broadcastMap = affine_map<(d0, d1) -> ()>
+    // It would affect as broadcast for scalar values in linalg::GenericOp.
+    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
+    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+    // The inputZp and outputZp may be either constant or non-constant,
+    // depending on whether dynamic extension is enabled.
+    // - If the zp's are non-constant, add them as an inputs to
+    // linalg::GenericOp by:
+    //     1. Pushing it into 'genericInputs'.
+    //     2. Appending a corresponding affine map to 'indexingMaps'.
+    // - If the zp's are constant, they would be generated as arith.constant.
+    int64_t iZpArg = 0;
+    if (failed(maybeIZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
+      indexingMaps.push_back(broadcastMap);
+      iZpArg = indexingMaps.size() - 1;
+    }
+    int64_t oZpArg = 0;
+    if (failed(maybeOZp)) {
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
+      indexingMaps.push_back(broadcastMap);
+      oZpArg = indexingMaps.size() - 1;
     }
 
     // Indexing maps for output values.
@@ -1526,36 +1661,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
           Type valueTy = value.getType();
 
           FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
-          if (failed(maybeIZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "input zero point cannot be statically determined");
-            return;
-          }
-
-          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
-          // Extend zeropoint for sub-32bits widths.
-          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
-          auto inputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
-                               *maybeIZp));
+          auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
+                                     nestedLoc, blockArgs, iZpArg);
 
           FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
-          if (failed(maybeOZp)) {
-            (void)rewriter.notifyMatchFailure(
-                op, "output zero point cannot be statically determined");
-            return;
-          };
+          auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
+                                      nestedLoc, blockArgs, oZpArg, true);
 
           IntegerType outIntType =
               cast<IntegerType>(blockArgs.back().getType());
           unsigned outBitWidth = outIntType.getWidth();
-          const int32_t outAttrBitwidth = 32;
           assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
-          auto outputZp = arith::ConstantOp::create(
-              nestedBuilder, loc,
-              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
-                               *maybeOZp));
 
           Value multiplier = multiplierConstant ? multiplierConstant
                                                 : blockArgs[multiplierArg];

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index a7a73ae904042..780c25a9445a0 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1538,6 +1538,92 @@ func.func @unsupportedRescaleInexactRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>
 
 // -----
 
+// Non-constant (dynamic extension) scalar multiplier, shift and zero points:
+// each tensor<1xT> operand is collapsed to tensor<T> and broadcast into the
+// linalg.generic body through the (d0) -> () map.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const(%arg0 : tensor<2xi8>, %multiplier : tensor<1xi32>, %shift : tensor<1xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[MULTIPLIER:%.+]] = tensor.collapse_shape %arg1 [] : tensor<1xi32> into tensor<i32>
+  // CHECK: [[SHIFT:%.+]] = tensor.collapse_shape %arg2 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[MULTIPLIER]], [[SHIFT]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<i32>, tensor<i8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+// Non-constant per-channel multiplier and shift (tensor<2xT>) are indexed by
+// the innermost dimension, while the non-constant tensor<1xi8> zero points are
+// collapsed and broadcast.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const_per_channel
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+// CHECK-SAME:  [[ARG1:%[0-9a-zA-Z_]*]]
+// CHECK-SAME:  [[ARG2:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const_per_channel(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xi8>, %output_zp : tensor<1xi8>) -> (tensor<2xi8>) {
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xi8> into tensor<i8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xi8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<i8>, tensor<i8>) outs([[INIT]] : tensor<2xi8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: i8, [[ARG4:%.*]]: i8, [[OUT:%.*]]: i8):
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extsi [[ARG3]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extsi [[ARG4]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c-128_i32 = arith.constant -128 : i32
+  // CHECK:    %c127_i32 = arith.constant 127 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c-128_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c127_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<2xi8>
+  return %0 : tensor<2xi8>
+}
+
+// -----
+
+// Non-constant per-channel multiplier and shift with non-constant unsigned
+// (ui8) zero points: the zero points are cast to signless i8 and zero-extended
+// with arith.extui rather than sign-extended.
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> ()>
+// CHECK-LABEL: @rescale_no_const_per_channel_input_output_zp_ui8
+// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
+// CHECK-SAME:  [[ARG1:%[0-9a-zA-Z_]*]]
+// CHECK-SAME:  [[ARG2:%[0-9a-zA-Z_]*]]
+func.func @rescale_no_const_per_channel_input_output_zp_ui8(%arg0 : tensor<2xi8>, %arg1 : tensor<2xi32>, %arg2 : tensor<2xi8>, %input_zp : tensor<1xui8>, %output_zp : tensor<1xui8>) -> (tensor<2xui8>) {
+  // CHECK: [[INPUT_ZP:%.+]] = tensor.collapse_shape %arg3 [] : tensor<1xui8> into tensor<ui8>
+  // CHECK: [[OUTPUT_ZP:%.+]] = tensor.collapse_shape %arg4 [] : tensor<1xui8> into tensor<ui8>
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<2xui8>
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]], [[ARG1]], [[ARG2]], [[INPUT_ZP]], [[OUTPUT_ZP]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<ui8>, tensor<ui8>) outs([[INIT]] : tensor<2xui8>) {
+  // CHECK:   ^bb0([[ARG0:%.*]]: i8, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i8, [[ARG3:%.*]]: ui8, [[ARG4:%.*]]: ui8, [[OUT:%.*]]: ui8):
+  // CHECK:    [[INPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG3]] : ui8 to i8
+  // CHECK:    [[INPUT_ZP_I32:%.+]] = arith.extui [[INPUT_ZP_I8]] : i8 to i32
+  // CHECK:    [[OUTPUT_ZP_I8:%.+]] = builtin.unrealized_conversion_cast [[ARG4]] : ui8 to i8
+  // CHECK:    [[OUTPUT_ZP_I32:%.+]] = arith.extui [[OUTPUT_ZP_I8]] : i8 to i32
+  // CHECK:    [[ARG0_I32:%.+]] = arith.extsi [[ARG0]] : i8 to i32
+  // CHECK:    [[TMP1:%.+]] = arith.subi [[ARG0_I32]], [[INPUT_ZP_I32]] : i32
+  // CHECK:    [[TMP2:%.+]] = tosa.apply_scale [[TMP1]], [[ARG1]], [[ARG2]] {rounding_mode = DOUBLE_ROUND} : (i32, i32, i8) -> i32
+  // CHECK:    [[TMP3:%.+]] = arith.addi [[TMP2]], [[OUTPUT_ZP_I32]] : i32
+  // CHECK:    %c0_i32 = arith.constant 0 : i32
+  // CHECK:    %c255_i32 = arith.constant 255 : i32
+  // CHECK:    [[MAX:%.+]] = arith.maxsi %c0_i32, [[TMP3]] : i32
+  // CHECK:    [[MIN:%.+]] = arith.minsi %c255_i32, [[MAX]] : i32
+  %0 = tosa.rescale %arg0, %arg1, %arg2, %input_zp, %output_zp {scale32 = true, rounding_mode = DOUBLE_ROUND, per_channel = true, input_unsigned = false, output_unsigned = true} : (tensor<2xi8>, tensor<2xi32>, tensor<2xi8>, tensor<1xui8>, tensor<1xui8>) -> tensor<2xui8>
+  return %0 : tensor<2xui8>
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: @reverse


        


More information about the Mlir-commits mailing list