[Mlir-commits] [mlir] [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (PR #155967)
Luke Hutton
llvmlistbot at llvm.org
Mon Sep 22 08:25:20 PDT 2025
================
@@ -1345,6 +1345,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};
+// Collapse a tensor<1xiN> into a rank-0 tensor<iN>.
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  SmallVector<ReassociationExprs, 1> reassociation;
+  // Create the collapsed (rank-0) type.
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto elemType = inputType.getElementType();
+  auto collapsedType = RankedTensorType::get({}, elemType);
+  // Emit the collapse op.
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+                                                  reassociation);
+}
+
+// The multiplier may be either constant or non-constant, depending on
+// whether the dynamic extension is enabled.
+// - If the multiplier is non-constant, or constant but per-channel, add it as
+//   an input to the linalg::GenericOp by:
+//   1. Pushing it into 'genericInputs'.
+//   2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the multiplier is a single constant, set 'multiplierConstant' instead.
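+// For example, given a rank-2 input, a per-channel multiplier gets the map
+// affine_map<(d0, d1) -> (d1)>, while a collapsed scalar multiplier gets the
+// broadcast map affine_map<(d0, d1) -> ()>.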
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
+    int64_t &multiplierArg) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> multiplierExprs{
+      rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // A single constant multiplier is materialized as a scalar constant;
+    // per-channel multiplier values are stored in a constant tensor instead.
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+    } else {
+      auto multiplierType =
+          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                                rewriter.getI32Type());
+      genericInputs.push_back(arith::ConstantOp::create(
+          rewriter, loc,
+          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    // If we are not rescaling per-channel, collapse the 1xN tensor to a
+    // rank-0 tensor and use a broadcast map.
+    auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier().getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>, which broadcasts the
+      // scalar value across all iteration dimensions of the
+      // linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getMultiplier());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  }
+  multiplierArg = indexingMaps.size() - 1;
+}
+
+// The shift may be either constant or non-constant, depending on whether the
+// dynamic extension is enabled (analogous to the multiplier case above).
+// - If the shift is non-constant, or constant but per-channel, add it as an
+//   input to the linalg::GenericOp by:
+//   1. Pushing it into 'genericInputs'.
+//   2. Appending a corresponding affine map to 'indexingMaps'.
+// - If the shift is a single constant, set 'shiftConstant' instead.
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
+    int64_t &shiftArg) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // A single constant shift is materialized as a scalar constant;
+    // per-channel shift values are stored in a constant tensor instead.
+    if (shiftValues.size() == 1) {
+      shiftConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+    } else {
+      auto shiftType =
+          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                                rewriter.getIntegerType(8));
+      genericInputs.push_back(arith::ConstantOp::create(
+          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    // If we are not rescaling per-channel, collapse the 1xN tensor to a
+    // rank-0 tensor and use a broadcast map.
+    auto tensorType = dyn_cast<RankedTensorType>(op.getShift().getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // broadcastMap = affine_map<(d0, d1) -> ()>, which broadcasts the
+      // scalar value across all iteration dimensions of the
+      // linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getShift(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getShift());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  }
+  shiftArg = indexingMaps.size() - 1;
+}
+
+// Return the extended zero point to be used in subsequent arithmetic
+// operations.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs, int64_t iZpArg) {
+  Value result;
+  // The zero point can be either constant or non-constant, depending on
+  // whether the dynamic extension is enabled. If 'maybeZp' is a failure, the
+  // zero point is non-constant and is passed to the linalg::GenericOp as an
+  // input.
+  if (failed(maybeZp)) {
+    result = blockArgs[iZpArg];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    }
+  } else {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Extend the zero point to at least 32 bits for sub-32-bit widths.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
+// Return the i32 output zero point to be used in subsequent arithmetic
+// operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs, int64_t oZpArg) {
+  Value result;
+  // The zero point can be either constant or non-constant, depending on
+  // whether the dynamic extension is enabled. If 'maybeZp' is a failure, the
+  // zero point is non-constant and is passed to the linalg::GenericOp as an
+  // input.
+  if (failed(maybeZp)) {
+    result = blockArgs[oZpArg];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+      result =
+          builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+    }
+  } else {
+    const int32_t attrBitwidth = 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
----------------
lhutton1 wrote:
nit: based on the logic here, I think some of the nesting could be removed by returning `result` directly?
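
For example, something along these lines (a rough, untested sketch of the same logic rewritten with early returns, keeping the existing signature):

```cpp
static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
                            FailureOr<int64_t> maybeZp, Location loc,
                            ValueRange blockArgs, int64_t oZpArg) {
  // Constant zero point: materialize it directly as an i32 constant.
  if (succeeded(maybeZp))
    return builder.create<arith::ConstantOp>(
        loc, IntegerAttr::get(builder.getI32Type(), *maybeZp));

  // Non-constant zero point: extend or truncate the block argument to i32.
  Value result = blockArgs[oZpArg];
  Type zpTy = result.getType();
  if (zpTy.getIntOrFloatBitWidth() < 32) {
    if (zpTy.isUnsignedInteger())
      return builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
    return builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
  }
  if (zpTy.getIntOrFloatBitWidth() > 32)
    return builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
  return result;
}
```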
https://github.com/llvm/llvm-project/pull/155967