[Mlir-commits] [mlir] [mlir][tosa] Change zero points of convolution ops to required inputs (PR #127679)

Tai Ly llvmlistbot at llvm.org
Wed Feb 19 08:49:04 PST 2025


================
@@ -279,36 +294,32 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  // We require an explicit input zero point and weight zero point for i8
-  // convolution.
-  if (!op.getInputZp() && !op.getWeightZp())
-    return inputEType.isInteger(8) ? failure() : success();
+  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
+  if (inputEType != inputZpEType) {
+    return op.emitOpError("expect both input and its zero point are the same "
+                          "element type, got ")
+           << inputEType << " and " << inputZpEType;
+  }
 
-  ElementsAttr inputZpAttr;
-  ElementsAttr weightZpAttr;
-  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
-      !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
-    op.emitOpError(
-        "bail out if the actual value of zero points cannot be determined");
-    return failure();
+  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
+  if (weightEType != weightZpEType) {
+    return op.emitOpError("expect both weight and its zero point are the same "
+                          "element type, got ")
+           << weightEType << " and " << weightZpEType;
   }
 
-  // Get and verify explicit zero points.
   int64_t inputZpVal;
-  int64_t weightZpVal;
-
-  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
-          .failed()) {
-    op.emitOpError("input zero point must be zero for non-int8 integer types");
-    return failure();
+  if (op.getInputZeroPoint(inputZpVal).succeeded()) {
+    if (op.verifyInputZeroPoint(inputZpVal).failed())
+      return op.emitOpError(
+          "input zero point must be zero for non-int8 integer types");
   }
 
-  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
-          .failed()) {
-    op.emitOpError("weight zero point must be zero for non-int8 integer types");
-    return failure();
+  int64_t weightZpVal;
+  if (op.getWeightZeroPoint(weightZpVal).succeeded()) {
----------------
Tai78641 wrote:

done

https://github.com/llvm/llvm-project/pull/127679


More information about the Mlir-commits mailing list