[Mlir-commits] [mlir] [mlir][tosa] Make Convolution Zero Points Inputs (PR #122939)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 14 09:13:39 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Jack Frankland (FranklandJack)

<details>
<summary>Changes</summary>

The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs.

Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately.

---

Patch is 172.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122939.diff


21 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+35) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+10-5) 
- (modified) mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h (+3) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+68-26) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+150-37) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+39-5) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+36-8) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+57-42) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (+58-10) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+44-22) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+8-4) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+45-38) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+108-72) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-8) 
- (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+3-1) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+11-4) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+9-3) 
- (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+34-17) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-31) 
- (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+4-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 66512cbe350ec8..1c1c65f48cd0af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -101,4 +101,39 @@ class TosaElementwiseOperator
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
 
+namespace mlir {
+namespace tosa {
+
+// Create a rank-0 const tensor for zero point of the source tensor.
+std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
+                                           Type srcElemType, int64_t zp = 0);
+
+// Get zero point value from the attribute argument.
+LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
+
+// Verify if zero point falls into valid range.
+template <typename T>
+LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
+  if constexpr (!std::is_same_v<T, Conv2DOp> &&
+                !std::is_same_v<T, Conv3DOp> &&
+                !std::is_same_v<T, DepthwiseConv2DOp> &&
+                !std::is_same_v<T, TransposeConv2DOp>) {
+    return failure();
+  }
+
+  if (!zpElemType.isIntOrFloat())
+    return failure();
+
+  if (!zpElemType.isInteger(8) && zp != 0)
+    return failure();
+
+  if (zp < -128 || zp > 127)
+    return failure();
+
+  return success();
+}
+
+} // namespace tosa
+} // namespace mlir
+
 #endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6b43c9a259b108..372447e278c4e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -103,11 +103,13 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -133,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
     Tosa_Tensor5D:$input,
     TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -164,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -346,13 +350,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 5e80745777b3b3..10dc5dd36cfa96 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                    Value input, Value weight);
 
+std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
+                                         Value weight);
+
 //// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
 MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
                                                        Value a, Value b);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d537aef5791031..fb213461b2acb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -258,7 +259,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
-    bool isQuantized = op.getQuantizationInfo().has_value();
+
+    ElementsAttr inputZpAttr;
+    ElementsAttr weightZpAttr;
+    if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+        !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
+      return rewriter.notifyMatchFailure(
+          op,
+          "bail out if the actual value of zero points cannot be determined");
+
+    // Get and verify explicit zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+        tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(inputZpAttr),
+                                          inputZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "input zero point must be zero for non-int8 integer types");
+
+    if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+        tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(weightZpAttr),
+                                          weightZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "weight zero point must be zero for non-int8 integer types");
+
+    const bool hasZp =
+        (inputZpVal != 0) || (weightZpVal != 0) || isa<IntegerType>(inputETy);
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -284,10 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (isQuantized) {
-      auto quantizationInfo = *op.getQuantizationInfo();
-      int64_t iZp = quantizationInfo.getInputZp();
-
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -295,11 +321,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (iZp < intMin || iZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.conv op quantization has zp outside of input range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -312,7 +338,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       // For 2D convolutions, we need to check if the target convolution op
       // wants a HWCF kernel layout.
       bool wantHwcf =
-          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                       : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
       if (wantHwcf) {
         // Transpose the kernel to match dimension ordering of the linalg
@@ -374,10 +400,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (isQuantized) {
-      auto quantizationInfo = *op.getQuantizationInfo();
-      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
-      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
+    if (hasZp) {
+      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,25 +465,40 @@ class DepthwiseConvConverter
         /*inputSizeDims=*/{1, 2},
         /*kernelSizeDims=*/{0, 1}, rewriter);
 
-    bool isQuantized = op->hasAttr("quantization_info");
-    IntegerAttr iZp;
-    IntegerAttr kZp;
-    if (isQuantized) {
-      auto quantizationInfo =
-          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
-      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
-      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
-    }
+    ElementsAttr inputZpAttr;
+    ElementsAttr weightZpAttr;
+    if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+        !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
+      return rewriter.notifyMatchFailure(
+          op,
+          "bail out if the actual value of zero points cannot be determined");
+
+    // Get and verify explicit zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+        tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
+            getElementTypeOrSelf(inputZpAttr), inputZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "input zero point must be zero for non-int8 integer types");
+
+    if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+        tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
+            getElementTypeOrSelf(weightZpAttr), weightZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "weight zero point must be zero for non-int8 integer types");
 
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (isQuantized) {
-      auto quantizationInfo =
-          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
-      int64_t iZp = quantizationInfo.getInputZp();
+    if (inputZpVal) {
+      const int64_t iZp = inputZpVal;
 
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -512,7 +552,7 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
-    if (!isQuantized) {
+    if (!hasZp && isa<FloatType>(inputETy)) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                            loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +579,8 @@ class DepthwiseConvConverter
               .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
+      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      IntegerAttr kZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
       Value conv =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 764a5db48e0787..c5c145a0eb329b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,11 +211,7 @@ static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
 
-  RankedTensorType weightType;
-  if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
-  else
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
 
   // Must be ranked tensor types
   if (!inputType) {
@@ -223,18 +219,49 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
   if (!weightType) {
-    if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
-      op.emitOpError("expect a ranked tensor for filter, got ")
-          << op.getFilter();
-    } else {
-      op.emitOpError("expect a ranked tensor for weight, got ")
-          << op.getWeight();
-    }
+    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
     return failure();
   }
 
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
+  auto biasEType =
+      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
+  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+    inputEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+    biasEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
+    // for now, only enforce bias element type == result element type for
+    // float types.
+    op.emitOpError(
+        "expect both bias and result to have same element type, got ")
+        << biasEType << " and " << resultEType;
+    return failure();
+  }
+
+  if (inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN() ||
+      weightEType.isFloat8E5M2() || weightEType.isFloat8E4M3FN()) {
+    if (inputEType != weightEType) {
+      op.emitOpError(
+          "expect both input and weight to have same element type, got ")
+          << inputEType << " and " << weightEType;
+      return failure();
+    }
+  }
 
   bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
   bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
@@ -247,14 +274,33 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  // Quantized type must have constructed the quantizationattr, and unquantized
-  // types should not have a quantizationattr.
-  if ((inputIsQuant && !op.getQuantizationInfo()) ||
-      (!inputIsQuant && op.getQuantizationInfo())) {
-    op.emitOpError("quantizationattr is required for quantized type, and not "
-                   "allowed for float type");
+  ElementsAttr inputZpAttr;
+  ElementsAttr weightZpAttr;
+  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+      !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
+    op.emitOpError(
+        "bail out if the actual value of zero points cannot be determined");
+    return failure();
+  }
+
+  // Get and verify explicit zero points.
+  int64_t inputZpVal;
+  int64_t weightZpVal;
+
+  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
+          .failed()) {
+    op.emitOpError("input zero point must be zero for non-int8 integer types");
     return failure();
   }
+
+  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
+          .failed()) {
+    op.emitOpError("weight zero point must be zero for non-int8 integer types");
+    return failure();
+  }
+
   return success();
 }
 
@@ -314,6 +360,39 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+// verify that inType and outType have same element types
+template <typename T>
+static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
+  auto inputType = llvm::dyn_cast<TensorType>(inType);
+  auto outputType = llvm::dyn_cast<TensorType>(outType);
+  if (!inputType) {
+    op.emitOpError("expect shaped tensor for input, got ") << inType;
+    return failure();
+  }
+  if (!outputType) {
+    op.emitOpError("expect shaped tensor for output, got ") << outType;
+    return failure();
+  }
+  auto inputElementType = inputType.getElementType();
+  auto outputElementType = outputType.getElementType();
+  auto inputQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
+  auto outputQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
+  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
+      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
+      inputElementType != outputElementType) {
+    // only check if both element types are int/index/float/UniformQuantized
+    // eg, not sure how to check quant::QuantizedType
+    // this happens in test_conv2d_q_grouped_convolution in
+    // tfl-to-tosa-pipeline.mlir
+    op.emitOpError("expect input and output to have same element type, got ")
+        << inputElementType << " and " << outputElementType;
+    return failure();
+  }
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::verify() {
   // Ensure output is of 32-bit integer
   const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -413,21 +492,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                      DenseI64ArrayAttr stride,
                                      DenseI64ArrayAttr dilation,
                                      TypeAttr accType) {
-
-  result.addOperands({input, weight, bias});
+  auto zps = createZPsAsConst(builder, input, weight);
+  result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("pad", pad);
   result.addAttribute("stride", stride);
   result.addAttribute("dilation", dilation);
   result.addAttribute("acc_type", accType);
-
-  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
-  if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
-    result.addTypes(
-        buildConvOpResultTypeInfo(builder, outputType, input, weight));
-  } else {
-    result.addTypes(outputType);
-  }
+  result.addTypes(outputType);
 }
 
 /// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -782,7 +853,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
   return success();
 }
 
-LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
+LogicalResult FullyConnectedOp::verify() { return success(); }
 
 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
@@ -1850,7 +1921,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape(adaptor.getFilter().getType());
+  ShapeAdaptor weightShape(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? weightShape.getDimSize(0)
@@ -1906,10 +1977,8 @@ LogicalResult IfOp::inferReturnTypeComponents(
       if (...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/122939


More information about the Mlir-commits mailing list