[Mlir-commits] [mlir] [mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (PR #155967)

Luke Hutton llvmlistbot at llvm.org
Mon Sep 22 08:25:20 PDT 2025


================
@@ -1345,6 +1345,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
   }
 };
 
+// Collapse 'input' (expected to be tensor<1xiN>) into a rank-0 tensor<iN>.
+// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
+static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
+                                  Location loc) {
+  SmallVector<ReassociationExprs, 1> reassociation;
+  // Build the rank-0 result type; 'reassociation' is left empty for it.
+  auto inputType = cast<RankedTensorType>(input.getType());
+  auto elemType = inputType.getElementType();
+  auto collapsedType = RankedTensorType::get({}, elemType);
+  // Emit the collapse op at 'loc' and return its result.
+  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
+                                                  reassociation);
+}
+
+// Make the rescale multiplier available to the linalg::GenericOp lowering.
+// - Constant ('isConstant') with one value: create an i32 arith.constant in
+//   'multiplierConstant'; nothing is pushed.
+// - Constant with several values: push a dense i32 constant tensor into
+//   'genericInputs' and a map over the last input dim into 'indexingMaps'.
+// - Non-constant: push op.getMultiplier() (collapsed if statically 1-elem).
+// 'multiplierArg' is always set to indexingMaps.size() - 1 on return.
+static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
+    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
+    int64_t &multiplierArg) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> multiplierExprs{
+      rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // A single value becomes a scalar constant; otherwise (per-channel,
+    // presumably) the multiplier values are stored in a constant tensor.
+    if (multiplierValues.size() == 1) {
+      multiplierConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
+    } else {
+      auto multiplierType =
+          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
+                                rewriter.getI32Type());
+      genericInputs.push_back(arith::ConstantOp::create(
+          rewriter, loc,
+          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
+
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    // A multiplier whose static leading dim is 1 is collapsed to rank 0 and
+    // given a broadcast map; otherwise it is indexed by the last input dim.
+    auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier().getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // broadcastMap has no results, e.g. affine_map<(d0, d1) -> ()>.
+      // It acts as a broadcast of the scalar value in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getMultiplier());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, multiplierExprs,
+                                            rewriter.getContext()));
+    }
+  }
+  multiplierArg = indexingMaps.size() - 1;
+}
+
+// Make the rescale shift available to the linalg::GenericOp lowering.
+// - Constant ('isConstant') with one value: create an i8 arith.constant in
+//   'shiftConstant'; nothing is pushed.
+// - Constant with several values: push a dense i8 constant tensor into
+//   'genericInputs' and a map over the last input dim into 'indexingMaps'.
+// - Non-constant: push op.getShift() (collapsed if statically 1-element).
+static void setupLinalgGenericOpInputAndIndexingMapForShift(
+    PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
+    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
+    bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
+    int64_t &shiftArg) {
+
+  auto loc = op.getLoc();
+  auto inputTy = cast<ShapedType>(op.getInput().getType());
+  unsigned rank = inputTy.getRank();
+  SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
+
+  if (isConstant) {
+    // A single value becomes a scalar constant; otherwise (per-channel,
+    // presumably) the shift values are stored in a constant tensor.
+    if (shiftValues.size() == 1) {
+      shiftConstant = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
+    } else {
+      auto shiftType =
+          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
+                                rewriter.getIntegerType(8));
+      genericInputs.push_back(arith::ConstantOp::create(
+          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  } else {
+    // A shift whose static leading dim is 1 is collapsed to rank 0 and
+    // given a broadcast map; otherwise it is indexed by the last input dim.
+    auto tensorType = dyn_cast<RankedTensorType>(op.getShift().getType());
+    if (tensorType && tensorType.hasStaticShape() &&
+        tensorType.getShape()[0] == 1) {
+      // broadcastMap has no results, e.g. affine_map<(d0, d1) -> ()>.
+      // It acts as a broadcast of the scalar value in linalg::GenericOp.
+      AffineMap broadcastMap =
+          AffineMap::get(rank, 0, {}, rewriter.getContext());
+      genericInputs.push_back(
+          collapse1xNTensorToN(rewriter, op.getShift(), loc));
+      indexingMaps.push_back(broadcastMap);
+    } else {
+      genericInputs.push_back(op.getShift());
+      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
+                                            /*symbolCount=*/0, shiftExprs,
+                                            rewriter.getContext()));
+    }
+  }
+  shiftArg = indexingMaps.size() - 1;
+}
+
+// Return the input Zp, widened to at least 32 bits, for later arithmetic.
+static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
+                              FailureOr<int64_t> maybeZp, Location loc,
+                              ValueRange blockArgs, int64_t iZpArg) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, Zp is non-constant and is read from
+  // 'blockArgs[iZpArg]'; it is zero- or sign-extended to i32 if narrower.
+  if (failed(maybeZp)) {
+    result = blockArgs[iZpArg];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    }
+  } else {
+    const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
+    // Constant Zp: use valueTy's bit width, but never fewer than 32 bits.
+    const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
+}
+
+// Return the i32 outputZp to be used in subsequent arithmetic operations.
+static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
+                            FailureOr<int64_t> maybeZp, Location loc,
+                            ValueRange blockArgs, int64_t oZpArg) {
+  Value result;
+  // The Zp value can be either constant or non-constant, depending on
+  // whether dynamic extension is enabled.
+  // If 'maybeZp' fails, it indicates that Zp is non-constant and will
+  // be passed as an input to linalg::GenericOp.
+  if (failed(maybeZp)) {
+    result = blockArgs[oZpArg];
+    auto zpTy = result.getType();
+    if (zpTy.getIntOrFloatBitWidth() < 32) {
+      if (zpTy.isUnsignedInteger()) {
+        result =
+            builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
+      } else {
+        result =
+            builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
+      }
+    } else if (zpTy.getIntOrFloatBitWidth() > 32) {
+      result =
+          builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
+    }
+  } else {
+    const int32_t attrBitwidth = 32;
+    result = builder.create<arith::ConstantOp>(
+        loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
+  }
+  return result;
----------------
lhutton1 wrote:

nit: based on the logic here, I think some of the nesting could be removed by returning early from each branch, rather than assigning to `result` and returning it at the end?

https://github.com/llvm/llvm-project/pull/155967


More information about the Mlir-commits mailing list