[Mlir-commits] [mlir] [mlir][tosa] Change zero points of convolution ops to required inputs (PR #127679)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 18 10:29:59 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This patch makes the input_zp and weight_zp operands of the convolution operators required
(they were previously optional), in order to align with the TOSA 1.0 specification.

Convolution operators affected are:
	CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D.


Change-Id: I7aa6e05580c83c617394c3f7fd18fba8c3f3b0ec

---

Patch is 151.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127679.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (-106) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+40-8) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+41-22) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+109-49) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+17-7) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+27-27) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+61-32) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+10-6) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+51-18) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+108-117) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+17-16) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+5-3) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+8-4) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-61) 
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+5-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 069073bc2d164..c18b46c9474fc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -161,112 +161,6 @@ namespace tosa {
 std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
                                            Type srcElemType, int64_t zp = 0);
 
-// Get zero point value from the attribute argument.
-LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
-
-// Verify if zero point falls into valid range.
-template <typename T>
-LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
-  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
-                !std::is_same_v<T, DepthwiseConv2DOp> &&
-                !std::is_same_v<T, TransposeConv2DOp>) {
-    return failure();
-  }
-
-  if (!zpElemType.isIntOrFloat())
-    return failure();
-
-  if (!zpElemType.isInteger(8) && zp != 0)
-    return failure();
-
-  if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
-    return failure();
-
-  if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
-    return failure();
-
-  return success();
-}
-
-// Helper type trait to determine if an operation is a tosa convolution.
-template <typename Op>
-struct IsTosaConv : std::false_type {};
-
-template <>
-struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
-template <>
-struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
-
-template <typename Op>
-constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
-
-// Helper struct to hold the zero points of a TOSA convolution operation as
-// named 64-bit integer fields.
-struct ConvZpPair {
-  ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
-      : inputZp(inputZp), weightZp(weightZp) {}
-  std::int64_t inputZp;
-  std::int64_t weightZp;
-};
-
-// Helper function which attempts to extract the zero points from a TOSA
-// convolution by matching them against defining ops which should be tosa.const
-// operations.
-//
-// There are three possible results:
-// 1. Failed to extract the zero-points i.e. they should exist and don't or they
-// do exist but are invalid.
-// 2. Succeeded in extracting zero-points.
-// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
-// convolution.
-using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
-template <typename TosaConvOp>
-std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
-extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
-  // Strictly speaking the base TOSA spec requires that for non int8 types
-  // zero points must be zero. However, in the dialect these operands are
-  // optional and only required for int8. They have no semantic meaning for
-  // non-quantized types and can therefore be safely ignored. This is case 3.
-  if (auto opElementTY =
-          cast<ShapedType>(op->getOperand(0).getType()).getElementType();
-      !opElementTY.isInteger(8))
-    return FailOrMaybeZP(std::nullopt);
-
-  // Now we know we should have a zero point check it is valid.
-  if (!op.getInputZp())
-    return rewriter.notifyMatchFailure(op, "missing input zero point");
-
-  // Helper to extract the zero point by matching its definition against a
-  // constant.
-  auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
-    ElementsAttr zpAttr;
-    if (!matchPattern(zpValue, m_Constant(&zpAttr)))
-      return std::nullopt;
-
-    int64_t zp;
-    if (tosa::getZeroPoint(zpAttr, zp).failed())
-      return std::nullopt;
-
-    return std::make_optional(zp);
-  };
-
-  auto maybeInputZp = extractZeroPoint(op.getInputZp());
-  if (!maybeInputZp)
-    return rewriter.notifyMatchFailure(op, "unable to extract input zp");
-
-  if (!op.getWeightZp())
-    return rewriter.notifyMatchFailure(op, "missing weight zero point");
-
-  auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
-  if (!maybeWeightZp)
-    return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
-
-  return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
-}
 } // namespace tosa
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index d11ba65a13736..8ac6a81dc38af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -105,8 +105,9 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -118,6 +119,13 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -136,8 +144,9 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
     Tosa_Tensor5D:$input,
     TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
@@ -149,6 +158,13 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
     Tosa_Tensor5D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -168,8 +184,9 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
@@ -181,6 +198,13 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_ConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
@@ -330,8 +354,9 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
-    Optional<Tosa_ScalarTensor>:$input_zp,
-    Optional<Tosa_ScalarTensor>:$weight_zp,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
@@ -343,6 +368,13 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
     Tosa_Tensor4D:$output
   );
 
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getWeightZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyWeightZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_TransConvOpQuantInfoBuilder];
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index a8fd536dd2548..bb7d5a23d9365 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
 
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (llvm::failed(failureOrMaybeZps))
-      return failure();
+    // Get and verify zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
+
+    if (op.verifyInputZeroPoint(inputZpVal).failed() ||
+        op.verifyWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
 
-    auto maybeZps = failureOrMaybeZps.value();
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -289,7 +299,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (maybeZps) {
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -297,11 +307,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.conv op quantization has zp outside of input range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       // For 2D convolutions, we need to check if the target convolution op
       // wants a HWCF kernel layout.
       bool wantHwcf =
-          maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
-                   : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
       if (wantHwcf) {
         // Transpose the kernel to match dimension ordering of the linalg
         // convolution operation.
@@ -376,9 +386,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (maybeZps) {
-      auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
-      auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
+    if (hasZp) {
+      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -441,18 +451,27 @@ class DepthwiseConvConverter
         /*inputSizeDims=*/{1, 2},
         /*kernelSizeDims=*/{0, 1}, rewriter);
 
-    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
-    if (llvm::failed(failureOrMaybeZps))
-      return failure();
+    // Get and verify zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "bail out if zero points cannot statically be determined");
 
-    auto maybeZps = failureOrMaybeZps.value();
+    if (op.verifyInputZeroPoint(inputZpVal).failed() ||
+        op.verifyWeightZeroPoint(weightZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "zero point must be zero for non-int8 integer types");
 
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (maybeZps) {
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -460,12 +479,12 @@ class DepthwiseConvConverter
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.depthwise_conv op quantization has zp outside of input "
                 "range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -505,7 +524,7 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
-    if (!maybeZps) {
+    if (!hasZp) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                            loc, linalgConvTy, ValueRange{input, weight},
@@ -532,8 +551,8 @@ class DepthwiseConvConverter
               .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
-      IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
-      IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
+      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
       Value conv =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 67021d6c07401..ad8077d3d47a6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,6 +211,18 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Tosa utilities.
+//===----------------------------------------------------------------------===//
+
+// Returns the element type of `type` (or `type` itself when it has no element
+// type), looking through a quantized element type to its storage type so that
+// callers can compare it directly against plain integer element types.
+static Type getStorageElementTypeOrSelf(Type type) {
+  Type elemTy = getElementTypeOrSelf(type);
+  auto quantTy = llvm::dyn_cast<mlir::quant::QuantizedType>(elemTy);
+  return quantTy ? quantTy.getStorageType() : elemTy;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
@@ -243,6 +255,9 @@ static LogicalResult verifyConvOp(T op) {
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
     inputEType = quantType.getStorageType();
 
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
+    weightEType = quantType.getStorageType();
+
   if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
     biasEType = quantType.getStorageType();
 
@@ -279,36 +294,32 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  // We require an explicit input zero point and weight zero point for i8
-  // convolution.
-  if (!op.getInputZp() && !op.getWeightZp())
-    return inputEType.isInteger(8) ? failure() : success();
+  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
+  if (inputEType != inputZpEType) {
+    return op.emitOpError("expect both input and its zero point are the same "
+                          "element type, got ")
+           << inputEType << " and " << inputZpEType;
+  }
 
-  ElementsAttr inputZpAttr;
-  ElementsAttr weightZpAttr;
-  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
-      !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
-    op.emitOpError(
-        "bail out if the actual value of zero points cannot be determined");
-    return failure();
+  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
+  if (weightEType != weightZpEType) {
+    return op.emitOpError("expect both weight and its zero point are the same "
+                          "element type, got ")
+           << weightEType << " and " << weightZpEType;
   }
 
-  // Get and verify explicit zero points.
   int64_t inputZpVal;
-  int64_t weightZpVal;
-
-  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
-          .failed()) {
-    op.emitOpError("input zero point must be zero for non-int8 integer types");
-    return failure();
+  if (op.getInputZeroPoint(inputZpVal).succeeded()) {
+    if (op.verifyInputZeroPoint(inputZpVal).failed())
+      return op.emitOpError(
+          "input zero point must be zero for non-int8 integer types");
   }
 
-  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
-      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
-          .failed()) {
-    op.emitOpError("weight zero point must be zero for non-int8 integer types");
-    return failure();
+  int64_t weightZpVal;
+  if (op.getWeightZeroPoint(weightZpVal).succeeded()) {
+    if (op.verifyWeightZeroPoint(weightZpVal).failed())
+      return op.emitOpError(
+          "weight zero point must be zero for non-int8 integer types");
   }
 
   return success();
@@ -1371,6 +1382,79 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
   return mlir::success();
 }
 
+// Extracts the scalar zero point held by the constant that produces `val`.
+//
+// Returns failure() only when the zero point cannot be determined statically,
+// i.e. `val` is not produced by a constant, or its element type has no integer
+// interpretation. This helper is called both from the op verifier and from
+// rewrite patterns, so it deliberately emits no diagnostics; validity of the
+// extracted value is checked by verifyZeroPoint().
+//
+//  - UniformQuantizedType: `zp` is the quantized type's zero point.
+//  - Float: floating-point zero points must be zero. A non-zero value is
+//    reported as `zp = 1` (a stand-in, not the actual value) so that
+//    verifyZeroPoint() rejects it instead of it being silently ignored.
+//  - Integer: `zp` is the sign-extended value of the first element.
+template <typename T>
+static LogicalResult getZeroPoint(T /*op*/, Value val, int64_t &zp) {
+  ElementsAttr zpAttr;
+  if (!matchPattern(val, m_Constant(&zpAttr)))
+    return failure();
+
+  Type zpElemType = zpAttr.getElementType();
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
+    zp = quantType.getZeroPoint();
+    return success();
+  }
+
+  if (llvm::isa<FloatType>(zpElemType)) {
+    zp = zpAttr.getValues<APFloat>()[0].isZero() ? 0 : 1;
+    return success();
+  }
+
+  if (llvm::isa<IntegerType>(zpElemType)) {
+    zp = zpAttr.getValues<APInt>()[0].getSExtValue();
+    return success();
+  }
+
+  // No integer interpretation for this element type: not statically known.
+  return failure();
+}
+
+// Checks that the zero point `zp`, read from operand `val`, is valid for `op`.
+// Emits a diagnostic on `op` and returns failure() when it is not.
+//
+// Only the convolution ops use tensor zero points so far; for any other op
+// type this returns failure() without emitting a diagnostic.
+template <typename T>
+static LogicalResult verifyZeroPoint(T op, Value val, int64_t zp) {
+  // TODO: Clean this up once the zero point change (attribute -> input tensor)
+  // is complete. Remaining ops: MatMul, Rescale, Negate and AvgPool2D.
+  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
+                !std::is_same_v<T, DepthwiseConv2DOp> &&
+                !std::is_same_v<T, TransposeConv2DOp>) {
+    return failure();
+  } else {
+    // Look through quantized element types to their storage type, consistent
+    // with the storage-type comparison in verifyConvOp and with the
+    // UniformQuantizedType handling in getZeroPoint.
+    Type zpElemType = getStorageElementTypeOrSelf(val.getType());
+
+    if (!zpElemType.isIntOrFloat())
+      return op.emitOpError("zero point must be of integer or float type");
+
+    if (llvm::isa<FloatType>(zpElemType) && zp != 0)
+      return op.emitOpError(
+          "non-zero zero point is not allowed for float types");
+
+    if (!zpElemType.isInteger(8) && zp != 0)
+      return op.emitOpError(
+          "zero point must be zero for non-int8 integer types");
+
+    // Report an out-of-range 8-bit zero point explicitly instead of failing
+    // silently, which left callers to emit an unrelated message.
+    if (zp < -128 || zp > 127)
+      return op.emitOpError("zero point must be in the range [-128, 127], got ")
+             << zp;
+
+    return success();
+  }
+}
+
+#define ZERO_POINT_HELPER(OP)                                                  \
+  LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) {                     \
+    return getZeroPoint(*this, getInputZp(), zp);                              \
+  }                                                            ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/127679


More information about the Mlir-commits mailing list