[Mlir-commits] [mlir] [mlir][tosa] Convert RESCALE op multiplier and shift from attributes to inputs (PR #129720)

Luke Hutton llvmlistbot at llvm.org
Wed Mar 5 03:19:48 PST 2025


================
@@ -2469,6 +2468,139 @@ LogicalResult TransposeConv2DOp::verify() {
   return success();
 }
 
+LogicalResult RescaleOp::verify() {
+  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
+  if (!inputType) {
+    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
+    return failure();
+  }
+  auto inputElementType = inputType.getElementType();
+  if (auto inputQType =
+          llvm::dyn_cast<quant::QuantizedType>(inputElementType)) {
+    inputElementType = inputQType.getStorageType();
+  }
+  if (!mlir::isa<IntegerType>(inputElementType)) {
+    emitOpError("expect input to have integer element type, got ")
+        << inputElementType;
+    return failure();
+  }
+  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
+  if (!outputType) {
+    emitOpError("expect shaped tensor for output, got ")
+        << getOutput().getType();
+    return failure();
+  }
+  auto outputElementType = outputType.getElementType();
+  if (auto outputQType =
+          llvm::dyn_cast<quant::QuantizedType>(outputElementType)) {
+    outputElementType = outputQType.getStorageType();
+  }
+  if (!mlir::isa<IntegerType>(outputElementType)) {
+    emitOpError("expect output to have integer element type, got ")
+        << outputElementType;
+    return failure();
+  }
+  auto input_zp = getInputZpAttr().getInt();
+  if (input_zp != 0) {
+    // only int8/uint8 and uint16 input can have non-zero input_zp
+    if (!inputElementType.isInteger(8) &&
+        !(inputElementType.isInteger(16) && getInputUnsigned())) {
+      emitOpError("expect input_zp of 0, got ") << input_zp;
+      return failure();
+    }
+    // input_zp must be either 0 or 32768 for uint16 input
+    if (inputElementType.isInteger(16) && getInputUnsigned() &&
+        input_zp != 32768) {
+      emitOpError(
+          "expect input_zp of 0 or 32768 for unsigned int16 input, got ")
+          << input_zp;
+      return failure();
+    }
+  }
+
+  auto output_zp = getOutputZpAttr().getInt();
+  if (output_zp != 0) {
+    // only int8/uint8 and uint16 output can have non-zero output_zp
+    if (!outputElementType.isInteger(8) &&
+        !(outputElementType.isInteger(16) && getOutputUnsigned())) {
+      emitOpError("expect output_zp of 0, got ") << output_zp;
+      return failure();
+    }
+    // output_zp must be either 0 or 32768 for uint16 output
+    if (outputElementType.isInteger(16) && getOutputUnsigned() &&
+        output_zp != 32768) {
+      emitOpError(
+          "expect output_zp of 0 or 32768 for unsigned int16 output, got ")
+          << output_zp;
+      return failure();
+    }
+  }
+
+  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
+  if (!multiplierType) {
+    emitOpError("expect shaped tensor for multiplier, got ")
+        << getMultiplier().getType();
+    return failure();
+  }
+  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
+  if (!shiftType) {
+    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
+    return failure();
+  }
+
+  // multiplier element type must be i32 for scale32 = true
+  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
+    emitOpError("expect i32 element type for multiplier for scale32=true, got ")
+        << multiplierType.getElementType();
+    return failure();
+  }
+
+  // multiplier element type must be i16 for scale32 = false
+  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
+    emitOpError(
+        "expect i16 element type for multiplier for scale32=false, got ")
+        << multiplierType.getElementType();
+    return failure();
+  }
+
+  // multiplier/shift must have shape = {numChannels},
+  // where numChannel is 1 if per_channel = false
+  // otherwise numChannel is dimension in input shape's last axis
+  int64_t numChannels = 1;
+  if (getPerChannel()) {
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+    numChannels = inputShape[inputShape.size() - 1];
+  }
+
+  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
+  // multiplier input has rank 1 by dialect definition
+  if (multiplierShape[0] != numChannels) {
----------------
lhutton1 wrote:

I think we need to be careful here in case the dim sizes are dynamic.

https://github.com/llvm/llvm-project/pull/129720


More information about the Mlir-commits mailing list