[Mlir-commits] [mlir] [mlir][tosa] Make Convolution Zero Points Inputs (PR #122939)

Jack Frankland llvmlistbot at llvm.org
Thu Jan 30 03:19:32 PST 2025


https://github.com/FranklandJack updated https://github.com/llvm/llvm-project/pull/122939

>From 4b6a660e5a7f58cf08e1930a2816bfdb8fa3eeb6 Mon Sep 17 00:00:00 2001
From: TatWai Chong <tatwai.chong at arm.com>
Date: Wed, 11 Sep 2024 21:37:20 -0700
Subject: [PATCH] [mlir][tosa] Make Convolution Zero Points Inputs

The TOSA-v1.0 specification moves the "zero point" parameters of the
convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and
TRANSPOSE_CONV2D from attributes to inputs.

Make the zero points of the convolutions in the MLIR TOSA dialect inputs
and update any transformations, materializations and lit tests
appropriately.
---
 .../mlir/Dialect/Tosa/IR/TosaOpBase.td        |   7 +
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h   | 118 ++++++++++
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  |  22 +-
 .../mlir/Dialect/Tosa/IR/TosaTypesBase.td     |   2 +
 .../mlir/Dialect/Tosa/Utils/QuantUtils.h      |   3 +
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  55 +++--
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 210 +++++++++++++++---
 .../Tosa/Transforms/TosaDecomposeConv2D.cpp   |  29 ++-
 .../Transforms/TosaDecomposeDepthwise.cpp     |  25 ++-
 .../Transforms/TosaDecomposeTransposeConv.cpp |  81 ++++---
 .../Tosa/Transforms/TosaValidation.cpp        |   2 +-
 mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp    |  63 +++++-
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  19 +-
 mlir/test/Dialect/Tosa/invalid.mlir           |  41 ++--
 mlir/test/Dialect/Tosa/ops.mlir               |   3 +-
 mlir/test/Dialect/Tosa/quant-test.mlir        |   4 +-
 .../Dialect/Tosa/tosa-decompose-conv2d.mlir   |  12 +-
 .../Tosa/tosa-decompose-depthwise.mlir        |   4 +-
 .../Tosa/tosa-decompose-transpose-conv.mlir   |  44 ++--
 19 files changed, 546 insertions(+), 198 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 29afd6c27302cc..14711aea5c1a8e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -264,4 +264,11 @@ class Tosa_InferShapedTypeOp<string mnemonic, list<Trait> traits = []>
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
+// The "SameVariadicOperandSize" trait allows us to pass optional arguments
+// for multiple zero points in convolution ops.
+class Tosa_ConvOp<string mnemonic, list<Trait> traits = []>
+    : Tosa_InferShapedTypeOp<mnemonic, !listconcat(traits,
+      [SameVariadicOperandSize])> {
+}
+
 #endif // TOSA_OP_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 27061002b02957..069073bc2d164a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -16,6 +16,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -29,6 +30,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 class PatternRewriter;
@@ -152,4 +154,120 @@ bool isa_tosa_shape_type(mlir::Type t);
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
 
+namespace mlir {
+namespace tosa {
+
+// Create a rank-1 const tensor for zero point of the source tensor.
+std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
+                                           Type srcElemType, int64_t zp = 0);
+
+// Get zero point value from the attribute argument.
+LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
+
+// Verify if zero point falls into valid range.
+template <typename T>
+LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
+  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
+                !std::is_same_v<T, DepthwiseConv2DOp> &&
+                !std::is_same_v<T, TransposeConv2DOp>) {
+    return failure();
+  }
+
+  if (!zpElemType.isIntOrFloat())
+    return failure();
+
+  if (!zpElemType.isInteger(8) && zp != 0)
+    return failure();
+
+  if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
+    return failure();
+
+  if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
+    return failure();
+
+  return success();
+}
+
+// Helper type trait to determine if an operation is a tosa convolution.
+template <typename Op>
+struct IsTosaConv : std::false_type {};
+
+template <>
+struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
+template <>
+struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
+template <>
+struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
+template <>
+struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
+
+template <typename Op>
+constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
+
+// Helper struct to hold the zero points of a TOSA convolution operation as
+// named 64-bit integer fields.
+struct ConvZpPair {
+  ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
+      : inputZp(inputZp), weightZp(weightZp) {}
+  std::int64_t inputZp;
+  std::int64_t weightZp;
+};
+
+// Helper function which attempts to extract the zero points from a TOSA
+// convolution by matching them against defining ops which should be tosa.const
+// operations.
+//
+// There are three possible results:
+// 1. Failed to extract the zero-points i.e. they should exist and don't or they
+// do exist but are invalid.
+// 2. Succeeded in extracting zero-points.
+// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
+// convolution.
+using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
+template <typename TosaConvOp>
+std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
+extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
+  // Strictly speaking the base TOSA spec requires that for non int8 types
+  // zero points must be zero. However, in the dialect these operands are
+  // optional and only required for int8. They have no semantic meaning for
+  // non-quantized types and can therefore be safely ignored. This is case 3.
+  if (auto opElementTY =
+          cast<ShapedType>(op->getOperand(0).getType()).getElementType();
+      !opElementTY.isInteger(8))
+    return FailOrMaybeZP(std::nullopt);
+
+  // Now we know we should have a zero point, so check that it is valid.
+  if (!op.getInputZp())
+    return rewriter.notifyMatchFailure(op, "missing input zero point");
+
+  // Helper to extract the zero point by matching its definition against a
+  // constant.
+  auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
+    ElementsAttr zpAttr;
+    if (!matchPattern(zpValue, m_Constant(&zpAttr)))
+      return std::nullopt;
+
+    int64_t zp;
+    if (tosa::getZeroPoint(zpAttr, zp).failed())
+      return std::nullopt;
+
+    return std::make_optional(zp);
+  };
+
+  auto maybeInputZp = extractZeroPoint(op.getInputZp());
+  if (!maybeInputZp)
+    return rewriter.notifyMatchFailure(op, "unable to extract input zp");
+
+  if (!op.getWeightZp())
+    return rewriter.notifyMatchFailure(op, "missing weight zero point");
+
+  auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
+  if (!maybeWeightZp)
+    return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
+
+  return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
+}
+} // namespace tosa
+} // namespace mlir
+
 #endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 9e3e41d288e4ac..80917cff84e3dc 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -92,7 +92,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
+def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
   let summary = "2D Convolution Operator";
 
   let description = [{
@@ -104,11 +104,12 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Optional<Tosa_ZeroPointTensor>:$input_zp,
+    Optional<Tosa_ZeroPointTensor>:$weight_zp,
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -123,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
+def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
   let summary = "3D Convolution operator";
 
   let description = [{
@@ -134,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
     Tosa_Tensor5D:$input,
     TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
+    Optional<Tosa_ZeroPointTensor>:$input_zp,
+    Optional<Tosa_ZeroPointTensor>:$weight_zp,
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -153,7 +155,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
 //===----------------------------------------------------------------------===//
 // Operator: depthwise_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
+def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
   let summary = "Depthwise 2D Convolution operator";
 
   let description = [{
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Optional<Tosa_ZeroPointTensor>:$input_zp,
+    Optional<Tosa_ZeroPointTensor>:$weight_zp,
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -338,7 +341,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: transpose_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
+def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
   let summary = "Transpose 2D Convolution operator.";
 
   let description = [{
@@ -348,13 +351,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Optional<Tosa_ZeroPointTensor>:$input_zp,
+    Optional<Tosa_ZeroPointTensor>:$weight_zp,
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 5693acf3a01db4..f970e85691bda3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -288,4 +288,6 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
 def Rank2TosaShape : TosaShapeOfRank<2>;
 def Rank4TosaShape : TosaShapeOfRank<4>;
 
+def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;
+
 #endif // TOSA_TYPES_BASE
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 5e80745777b3b3..10dc5dd36cfa96 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                    Value input, Value weight);
 
+std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
+                                         Value weight);
+
 //// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
 MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
                                                        Value a, Value b);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 57a5fe75a007b7..214c10cdc7d1c6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
-    bool isQuantized = op.getQuantizationInfo().has_value();
+
+    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+    if (llvm::failed(failureOrMaybeZps))
+      return failure();
+
+    auto maybeZps = failureOrMaybeZps.value();
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -284,10 +289,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (isQuantized) {
-      auto quantizationInfo = *op.getQuantizationInfo();
-      int64_t iZp = quantizationInfo.getInputZp();
-
+    if (maybeZps) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -295,11 +297,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (iZp < intMin || iZp > intMax)
+      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.conv op quantization has zp outside of input range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       // For 2D convolutions, we need to check if the target convolution op
       // wants a HWCF kernel layout.
       bool wantHwcf =
-          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
-                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+          maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                   : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
       if (wantHwcf) {
         // Transpose the kernel to match dimension ordering of the linalg
         // convolution operation.
@@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (isQuantized) {
-      auto quantizationInfo = *op.getQuantizationInfo();
-      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
-      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
+    if (maybeZps) {
+      auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
+      auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,26 +441,18 @@ class DepthwiseConvConverter
         /*inputSizeDims=*/{1, 2},
         /*kernelSizeDims=*/{0, 1}, rewriter);
 
-    bool isQuantized = op->hasAttr("quantization_info");
-    IntegerAttr iZp;
-    IntegerAttr kZp;
-    if (isQuantized) {
-      auto quantizationInfo =
-          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
-      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
-      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
-    }
+    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+    if (llvm::failed(failureOrMaybeZps))
+        return failure();
+
+    auto maybeZps = failureOrMaybeZps.value();
 
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (isQuantized) {
-      auto quantizationInfo =
-          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
-      int64_t iZp = quantizationInfo.getInputZp();
-
+    if (maybeZps) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -467,12 +460,12 @@ class DepthwiseConvConverter
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (iZp < intMin || iZp > intMax)
+      if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.depthwise_conv op quantization has zp outside of input "
                 "range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -512,7 +505,7 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
-    if (!isQuantized) {
+    if (!maybeZps) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                            loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +532,8 @@ class DepthwiseConvConverter
               .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
+      IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
+      IntegerAttr kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
       Value conv =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c0b419b6f473c8..bf6bc3f10a801f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -217,33 +217,59 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
 
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
-  // All TOSA conv ops have an input() and weight().
+  // All TOSA conv ops have input and weight arguments, which must be ranked
+  // tensors.
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
-
-  RankedTensorType weightType;
-  if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
-  else
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
-
-  // Must be ranked tensor types
   if (!inputType) {
     op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
     return failure();
   }
+
+  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
   if (!weightType) {
-    if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
-      op.emitOpError("expect a ranked tensor for filter, got ")
-          << op.getFilter();
-    } else {
-      op.emitOpError("expect a ranked tensor for weight, got ")
-          << op.getWeight();
-    }
+    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
     return failure();
   }
 
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
+  auto biasEType =
+      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
+  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+    inputEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+    biasEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
+    // for now, only enforce bias element type == result element type for
+    // float types.
+    op.emitOpError(
+        "expect both bias and result to have same element type, got ")
+        << biasEType << " and " << resultEType;
+    return failure();
+  }
+
+  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
+      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
+    if (inputEType != weightEType) {
+      op.emitOpError(
+          "expect both input and weight to have same element type, got ")
+          << inputEType << " and " << weightEType;
+      return failure();
+    }
+  }
 
   bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
   bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
@@ -256,14 +282,53 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  // Quantized type must have constructed the quantizationattr, and unquantized
-  // types should not have a quantizationattr.
-  if ((inputIsQuant && !op.getQuantizationInfo()) ||
-      (!inputIsQuant && op.getQuantizationInfo())) {
-    op.emitOpError("quantizationattr is required for quantized type, and not "
-                   "allowed for float type");
-    return failure();
+  // We require an explicit input zero point and weight zero point for i8
+  // convolution. It isn't invalid to provide zero points for other types but
+  // they will be ignored.
+  if (inputEType.isInteger(8)) {
+
+    // Check we have zero points.
+    if (!op.getInputZp()) {
+      op.emitError("missing input zero point");
+      return failure();
+    }
+
+    if (!op.getWeightZp()) {
+      op.emitError("missing weight zero point");
+      return failure();
+    }
+
+    ElementsAttr inputZpAttr;
+    ElementsAttr weightZpAttr;
+    if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+        !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
+      op.emitOpError(
+          "bail out if the actual value of zero points cannot be determined");
+      return failure();
+    }
+
+    // Get and verify explicit zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+        tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
+            .failed()) {
+      op.emitOpError(
+          "input zero point must be zero for non-int8 integer types");
+      return failure();
+    }
+
+    if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+        tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr),
+                                 weightZpVal)
+            .failed()) {
+      op.emitOpError(
+          "weight zero point must be zero for non-int8 integer types");
+      return failure();
+    }
   }
+
   return success();
 }
 
@@ -322,6 +387,39 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+// verify that inType and outType have same element types
+template <typename T>
+static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
+  auto inputType = llvm::dyn_cast<TensorType>(inType);
+  auto outputType = llvm::dyn_cast<TensorType>(outType);
+  if (!inputType) {
+    op.emitOpError("expect shaped tensor for input, got ") << inType;
+    return failure();
+  }
+  if (!outputType) {
+    op.emitOpError("expect shaped tensor for output, got ") << outType;
+    return failure();
+  }
+  auto inputElementType = inputType.getElementType();
+  auto outputElementType = outputType.getElementType();
+  auto inputQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
+  auto outputQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
+  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
+      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
+      inputElementType != outputElementType) {
+    // only check if both element types are int/index/float/UniformQuantized
+    // eg, not sure how to check quant::QuantizedType
+    // this happens in test_conv2d_q_grouped_convolution in
+    // tfl-to-tosa-pipeline.mlir
+    op.emitOpError("expect input and output to have same element type, got ")
+        << inputElementType << " and " << outputElementType;
+    return failure();
+  }
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::verify() {
   // Ensure output is of 32-bit integer
   const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -421,21 +519,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                      DenseI64ArrayAttr stride,
                                      DenseI64ArrayAttr dilation,
                                      TypeAttr accType) {
-
-  result.addOperands({input, weight, bias});
+  auto zps = createZPsAsConst(builder, input, weight);
+  result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("pad", pad);
   result.addAttribute("stride", stride);
   result.addAttribute("dilation", dilation);
   result.addAttribute("acc_type", accType);
-
-  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
-  if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
-    result.addTypes(
-        buildConvOpResultTypeInfo(builder, outputType, input, weight));
-  } else {
-    result.addTypes(outputType);
-  }
+  result.addTypes(outputType);
 }
 
 /// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -790,7 +880,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
   return success();
 }
 
-LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
+LogicalResult FullyConnectedOp::verify() { return success(); }
 
 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
@@ -1984,7 +2074,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape(adaptor.getFilter().getType());
+  ShapeAdaptor weightShape(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? weightShape.getDimSize(0)
@@ -2280,6 +2370,54 @@ void WhileOp::print(OpAsmPrinter &parser) {
   parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
 }
 
+LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
+  Type zpElemType = zpAttr.getElementType();
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(zpElemType)) {
+    zp = quantType.getZeroPoint();
+    return success();
+  }
+  if (llvm::isa<FloatType>(zpElemType)) {
+    // non-zero zero point is not allowed for float types.
+    if (!zpAttr.getValues<APFloat>()[0].isZero())
+      return failure();
+    zp = 0;
+    return success();
+  }
+  if (llvm::isa<IntegerType>(zpElemType)) {
+    zp = zpAttr.getValues<APInt>()[0].getSExtValue();
+    return success();
+  }
+  // zero point is not allowed for unsupported types.
+  return failure();
+}
+
+// Create a rank-1 const tensor for zero point of the source tensor.
+std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
+                                                       Location loc,
+                                                       Type srcElemType,
+                                                       int64_t zp) {
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
+    srcElemType = quantType.getStorageType();
+
+  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
+  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
+    srcElemType = quantType.getStorageType();
+  if (llvm::isa<FloatType>(srcElemType)) {
+    auto zpAttr = DenseElementsAttr::get(
+        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
+    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
+  }
+  if (llvm::isa<IntegerType>(srcElemType)) {
+    auto zpAttr =
+        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
+    return builder.create<tosa::ConstOp>(loc, zpType, zpAttr);
+  }
+  llvm::errs() << "zero point is not allowed for unsupported data types\n";
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Shape and Shape Operators Helper functions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index cb08360f902286..b34472ea453b54 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -59,19 +59,17 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     for (const auto &it : llvm::enumerate(padAttr))
       pad[it.index() + 2] = it.value();
 
+    Type inputETy = inputType.getElementType();
     if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
-      Type inputETy = inputType.getElementType();
-      Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
-      if (op.getQuantizationInfo()) {
-        auto quantizationInfo = op.getQuantizationInfo();
-        int64_t iZp = quantizationInfo->getInputZp();
+      auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+      if (failed(failureOrMaybeZps))
+          return failure();
 
-        if (!validIntegerRange(cast<IntegerType>(inputETy), iZp))
-          return rewriter.notifyMatchFailure(
-              op, "tosa.conv op quantization has zp outside of input range");
+      auto maybeZps = failureOrMaybeZps.value();
 
-        zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
-      }
+      Attribute zeroAttr =
+          maybeZps ? rewriter.getIntegerAttr(inputETy, maybeZps->inputZp)
+                   : rewriter.getZeroAttr(inputETy);
 
       llvm::SmallVector<int64_t> newShape(inputType.getShape());
 
@@ -125,13 +123,20 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     auto fullyConnectedShapeType =
         RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
 
+    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+    if (failed(failureOrMaybeZps))
+      return failure();
+
+    auto maybeZps = failureOrMaybeZps.value();
     Value fullyConnectedValue;
-    if (op.getQuantizationInfo()) {
+    if (maybeZps) {
+      auto zeroPointAttr = rewriter.getAttr<tosa::ConvOpQuantizationAttr>(
+          maybeZps->inputZp, maybeZps->weightZp);
       fullyConnectedValue =
           rewriter
               .create<tosa::FullyConnectedOp>(
                   op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.getBias(), *op.getQuantizationInfo())
+                  reshapedWeight, op.getBias(), zeroPointAttr)
               .getResult();
     } else {
       fullyConnectedValue = rewriter
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 181aff3a9ce04f..ee857f1998a54d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -61,20 +61,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
                     rewriter.getDenseI64ArrayAttr(revisedInputShape))
                 .getResult();
 
-    if (inputType.getElementType() != resultType.getElementType()) {
-      inputType = inputType.clone(resultType.getElementType());
+    Type inputETy = inputType.getElementType();
+    Type weightETy = weightType.getElementType();
+    Type resultETy = resultType.getElementType();
+
+    if (inputETy != resultETy) {
+      inputType = inputType.clone(resultETy);
       input = rewriter.create<tosa::CastOp>(op.getLoc(), inputType, input);
     }
 
-    if (weightType.getElementType() != resultType.getElementType()) {
-      weightType = weightType.clone(resultType.getElementType());
+    if (weightETy != resultETy) {
+      weightType = weightType.clone(resultETy);
       weight = rewriter.create<tosa::CastOp>(op.getLoc(), weightType, weight);
     }
 
-    if (auto quantizationInfo = op.getQuantizationInfo()) {
-      auto iZp = quantizationInfo->getInputZp();
-      auto wZp = quantizationInfo->getWeightZp();
+    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+    if (failed(failureOrMaybeZps))
+      return failure();
 
+    auto maybeZps = failureOrMaybeZps.value();
+    if (maybeZps) {
       auto applyZp = [&](Value val, int64_t zp) -> Value {
         if (zp == 0)
           return val;
@@ -89,8 +95,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
                                             zpVal);
       };
 
-      input = applyZp(input, iZp);
-      weight = applyZp(weight, wZp);
+      input = applyZp(input, maybeZps->inputZp);
+      weight = applyZp(weight, maybeZps->weightZp);
     }
 
     ArrayRef<int64_t> padAttr = op.getPad();
@@ -99,7 +105,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
       pad[it.index() + 2] = it.value();
 
     if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) {
-      Type inputETy = inputType.getElementType();
       Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
 
       llvm::SmallVector<int64_t> newShape(inputType.getShape());
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 807f9cd683bb8c..b12a1f3db18d86 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -69,22 +69,12 @@ class TransposeConvNonStridedConverter
     auto reverse2 = rewriter.create<tosa::ReverseOp>(
         loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2));
 
-    Value conv2d;
-    if (op.getQuantizationInfo()) {
-      conv2d = rewriter.create<tosa::Conv2DOp>(
-          loc, resultTy, input, reverse2, bias,
-          rewriter.getDenseI64ArrayAttr(convPad),
-          rewriter.getDenseI64ArrayAttr(stride),
-          rewriter.getDenseI64ArrayAttr({1, 1}),
-          /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
-    } else {
-      conv2d = rewriter.create<tosa::Conv2DOp>(
-          loc, resultTy, input, reverse2, bias,
-          rewriter.getDenseI64ArrayAttr(convPad),
-          rewriter.getDenseI64ArrayAttr(stride),
-          rewriter.getDenseI64ArrayAttr({1, 1}),
-          /* acc_type = */ op.getAccTypeAttr());
-    }
+    Value conv2d = rewriter.create<tosa::Conv2DOp>(
+        loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(),
+        rewriter.getDenseI64ArrayAttr(convPad),
+        rewriter.getDenseI64ArrayAttr(stride),
+        rewriter.getDenseI64ArrayAttr({1, 1}),
+        /* acc_type = */ op.getAccType());
 
     rewriter.replaceOp(op, conv2d);
     return success();
@@ -144,12 +134,16 @@ class TransposeConvStridedConverter
     Value weightPaddingVal =
         getTosaConstShape(rewriter, op->getLoc(), weightPadding);
 
-    if (op.getQuantizationInfo().has_value()) {
-      auto quantInfo = op.getQuantizationInfo().value();
+    auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
+    if (failed(failureOrMaybeZps))
+      return failure();
+
+    auto maybeZps = failureOrMaybeZps.value();
+    if (maybeZps) {
       weight = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
           weightPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
+          rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->weightZp));
 
     } else {
       weight = CreateOpAndInferShape<tosa::PadOp>(
@@ -205,12 +199,11 @@ class TransposeConvStridedConverter
     Value inputPaddingVal =
         getTosaConstShape(rewriter, op->getLoc(), inputPadding);
 
-    if (op.getQuantizationInfo().has_value()) {
-      auto quantInfo = op.getQuantizationInfo().value();
+    if (maybeZps) {
       input = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
           inputPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
+          rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->inputZp));
     } else {
       input = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
@@ -227,28 +220,32 @@ class TransposeConvStridedConverter
                                   biasETy),
             rewriter.getZeroAttr(biasETy)));
 
-    // Perform the convolution using the zero bias.
-    Value conv2d;
-    if (op.getQuantizationInfo()) {
-      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
-                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
-                   weight, zeroBias,
-                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
-                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
-                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
-                   /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
-                   .getResult();
-    } else {
-      conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
-                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
-                   weight, zeroBias,
-                   /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
-                   /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
-                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
-                   /* acc_type = */ op.getAccTypeAttr())
-                   .getResult();
+    Value inputZp, weightZp;
+    if (maybeZps) {
+      auto maybeInputZp = createZeroPointTensor(
+          rewriter, loc, getElementTypeOrSelf(input.getType()), maybeZps->inputZp);
+      auto maybeWeightZp = createZeroPointTensor(
+          rewriter, loc, getElementTypeOrSelf(weight.getType()), maybeZps->weightZp);
+
+      if (!maybeInputZp.has_value() || !maybeWeightZp.has_value()) {
+        return rewriter.notifyMatchFailure(
+            op, "fail to create a const zero point tensor");
+      }
+
+      inputZp = *maybeInputZp;
+      weightZp = *maybeWeightZp;
     }
 
+    // Perform the convolution using the zero bias.
+    Value conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
+                       rewriter, loc, UnrankedTensorType::get(resultETy), input,
+                       weight, zeroBias, inputZp, weightZp,
+                       /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
+                       /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+                       /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+                       /* acc_type = */ op.getAccType())
+                       .getResult();
+
     // Factor the resulting width / height.
     ShapedType convTy = cast<ShapedType>(conv2d.getType());
     Type convETy = convTy.getElementType();
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index a49870687fdc60..678bb47935bd20 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -357,7 +357,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
   bool levelCheckTransposeConv2d(Operation *op) {
     if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
       if (ShapedType filterType =
-              dyn_cast<ShapedType>(transpose.getFilter().getType())) {
+              dyn_cast<ShapedType>(transpose.getWeight().getType())) {
         auto shape = filterType.getShape();
         assert(shape.size() == 4);
         // level check kernel sizes for kH and KW
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
index 5c546f59cde413..0f7562767001cf 100644
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -112,19 +112,14 @@ void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier,
 #define GET_QTYPE(inputType)                                                   \
   (llvm::dyn_cast<quant::QuantizedType>((inputType).getElementType()))
 
-/// Method to build ConvOpQuantizationAttr, called from
-/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
-/// input_zp: input zeropoint
-/// weight_zp: weight zeropoint.
-ConvOpQuantizationAttr
-mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
-                                        Value weight) {
+static std::optional<std::pair<std::int64_t, std::int64_t>>
+getConvZeroPoints(Value input, Value weight) {
 
   auto inputType = dyn_cast<ShapedType>(input.getType());
   auto weightType = dyn_cast<ShapedType>(weight.getType());
 
   if (!inputType || !weightType)
-    return nullptr;
+    return std::nullopt;
 
   auto inputQType = GET_UQTYPE(inputType);
   auto weightPerTensorQType = GET_UQTYPE(weightType);
@@ -150,10 +145,58 @@ mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
       weightZp = weightPerAxisQType.getZeroPoints().front();
     }
 
-    return builder.getAttr<tosa::ConvOpQuantizationAttr>(inputZp, weightZp);
+    return std::make_pair(inputZp, weightZp);
   }
 
-  return nullptr;
+  return std::nullopt;
+}
+
+std::pair<Value, Value>
+mlir::tosa::createZPsAsConst(OpBuilder &builder, Value input, Value weight) {
+  std::int64_t inputZp, weightZp;
+
+  auto inputEType = getElementTypeOrSelf(input.getType());
+  auto weightEType = getElementTypeOrSelf(weight.getType());
+
+  if (mlir::isa<FloatType>(inputEType) && mlir::isa<FloatType>(weightEType)) {
+    inputZp = 0;
+    weightZp = 0;
+  } else {
+    auto maybeZps = getConvZeroPoints(input, weight);
+    if (!maybeZps.has_value())
+      return {};
+
+    inputZp = maybeZps->first;
+    weightZp = maybeZps->second;
+  }
+
+  auto maybeInputZpValue =
+      createZeroPointTensor(builder, input.getLoc(), inputEType, inputZp);
+  if (!maybeInputZpValue.has_value())
+    return {};
+
+  auto maybeWeightZpValue =
+      createZeroPointTensor(builder, weight.getLoc(), weightEType, weightZp);
+  if (!maybeWeightZpValue.has_value())
+    return {};
+
+  return std::make_pair(*maybeInputZpValue, *maybeWeightZpValue);
+}
+
+/// Method to build ConvOpQuantizationAttr, called from
+/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder:
+/// input_zp: input zeropoint
+/// weight_zp: weight zeropoint.
+ConvOpQuantizationAttr
+mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input,
+                                        Value weight) {
+
+  auto maybeZps = getConvZeroPoints(input, weight);
+  if (!maybeZps.has_value())
+    return nullptr;
+
+  return builder.getAttr<tosa::ConvOpQuantizationAttr>(maybeZps->first,
+                                                       maybeZps->second);
 }
 
 /// Builds MatMulOpQuantizationAttr, called from
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5eeaebb384e408..116cd045aa0d3a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -544,7 +544,8 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
   // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
 
-  %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
   return
 }
 
@@ -687,7 +688,9 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C22]]
   // CHECK: linalg.conv_2d_nhwc_fhwc_q
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
+  %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x1024xi32>
   return
 }
 
@@ -799,7 +802,9 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3
   // CHECK:   [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
   // CHECK:   linalg.yield [[ADD]] : i32
   // CHECK: } -> tensor<1x12x12x512xi32>
-  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32>
+  %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x512xi32>
   return
 }
 
@@ -823,7 +828,9 @@ func.func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 :
   // CHECK:   [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32
   // CHECK:   linalg.yield [[ADD]] : i32
   // CHECK: } -> tensor<1x10x10x512xi32>
-  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>)  -> tensor<1x10x10x512xi32>
+  %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 2> } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>)  -> tensor<1x10x10x512xi32>
   return
 }
 
@@ -905,7 +912,9 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
   // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
   // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
 
-  %0 = tosa.conv3d %input, %weights, %bias {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>)  -> tensor<1x47x45x43x28xi32>
+  %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.conv3d %input, %weights, %bias , %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>)  -> tensor<1x47x45x43x28xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 51d7f828510613..76a722bc83e7d6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -33,45 +33,41 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>,
 // -----
 
 func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{expect a ranked tensor for input, got <block argument> of type 'tensor<*xi8>' at index: 0}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+           : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }
 
 // -----
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
-  return %0 : tensor<1x27x27x16xi8>
-}
-
-// -----
-
-func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
-  // expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
-           : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+           : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }
 
 // -----
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
-           : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+           : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8>
   return %0 : tensor<1x27x27x16xi8>
 }
 
 // -----
 
 func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> {
+  %input_zp = "tosa.const"() {value = dense<0> : tensor<1xi16>} : () -> tensor<1xi16>
+  %weight_zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
-           : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>) -> tensor<1x27x27x16xi16>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+           : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x27x27x16xi16>
   return %0 : tensor<1x27x27x16xi16>
 }
 
@@ -123,25 +119,28 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3
 // -----
 
 func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> {
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}}
-  %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>}
-           : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>) -> tensor<1x4x8x21x34xi8>
+  %0 = tosa.conv3d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>}
+           : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi8>
   return %0 : tensor<1x4x8x21x34xi8>
 }
 
 // -----
 
 func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> {
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}}
-  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>) -> tensor<1x4x4x8xi8>
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi8>
   return %0 : tensor<1x4x4x8xi8>
 }
 
 // -----
 
 func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> {
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
   // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}}
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>) -> tensor<1x32x32x16xi8>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8>
   return %0 : tensor<1x32x32x16xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 9eba2f7e5a06e4..3ac6b4b5f8b0bd 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -63,7 +63,8 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %
 func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> {
   %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
   %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
-  %2 = "tosa.conv2d"(%arg0, %0, %1) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
+  %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8>
+  %2 = "tosa.conv2d"(%arg0, %0, %1, %zp, %zp) {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x1x3xi32>
   %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
   return %3 : tensor<1x1x1x3xi8>
 }
diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir
index 6437f12e3ff85e..ee6caf285a2484 100644
--- a/mlir/test/Dialect/Tosa/quant-test.mlir
+++ b/mlir/test/Dialect/Tosa/quant-test.mlir
@@ -12,7 +12,9 @@ func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32
 // CHECK-LABEL: test_build_mult_and_shift
 func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>> {
   // CHECK: tosa.conv2d
-  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = i32, pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -1, weight_zp = 0>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
+  %input_zp = "tosa.const"() {value = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2, %input_zp, %weight_zp) {acc_type = i32, pad = array<i64: 1, 1, 2, 2>, dilation = array<i64: 2, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8x!quant.uniform<i8:f32, 0.015684768557548523>>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
   return %0 : tensor<1x32x32x16x!quant.uniform<i32:f32, 0.078431375324726104>>
 
 }
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 95d9bb1b98ab74..685f799bd3d2bf 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -33,7 +33,9 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
   // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
   // CHECK-SAME: -> tensor<4x10x10x3xi32>
   // CHECK: return %[[VAR3]]
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32>
+  %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x3xi32>
   return %0 : tensor<4x10x10x3xi32>
 }
 
@@ -50,7 +52,9 @@ func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384
 // CHECK:           %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
 // CHECK:           return %[[VAL_6]] : tensor<?x14x14x384xi32>
 // CHECK:         }
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>
+  %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<?x14x14x384xi32>
   return %0 : tensor<?x14x14x384xi32>
 }
 
@@ -65,6 +69,8 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1:
   // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
   // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
   // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32>
+  %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x12x12x3xi32>
   return %0 : tensor<4x12x12x3xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index 5f36dd3b3d137c..ce29d1a498b4f8 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -38,7 +38,9 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<
   // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array<i64: 4, 10, 10, 6>}
   // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 6>}
   // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]]
-  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 7, weight_zp = 11>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
+  %input_zp = "tosa.const"() {value = dense<7> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
   return %0 : tensor<4x10x10x6xi32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 12691f2e325a28..bb6de82ee10532 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -6,7 +6,7 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x
   // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
   // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2
   // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32>
   return %0 : tensor<2x18x19x5xf32>
 }
 
@@ -15,10 +15,14 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x
 // CHECK-LABEL: @transpose_conv2d_quantized
 
 func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) {
+  // CHECK-DAG: %[[INPUT_ZP:.+]]  = "tosa.const"() <{value = dense<-6> : tensor<1xi8>}
+  // CHECK-DAG: %[[WEIGHT_ZP:.+]]  = "tosa.const"() <{value = dense<11> : tensor<1xi8>}
   // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
   // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32}
-  // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32>
+  // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 2, 2, 5, 5>, stride = array<i64: 1, 1>}
+  %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x18x19x5xi32>
   return %0 : tensor<2x18x19x5xi32>
 }
 
@@ -26,17 +30,20 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor
 
 // CHECK-LABEL: @transpose_conv2d_quantized_padded
 func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) {
-  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %0 {axis = 2 : i32}
+  // CHECK-DAG: %[[INPUT_ZP:.+]]  = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
+  // CHECK-DAG: %[[WEIGHT_ZP:.+]]  = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
+  // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32}
   // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32}
-  // CHECK: tosa.conv2d %arg0, %1, %arg2
+  // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
   // CHECK-SAME: dilation = array<i64: 1, 1>, pad = array<i64: 3, 4, 8, 9>,
-  // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {
+  // CHECK-SAME: stride = array<i64: 1, 1>}
+  %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
     acc_type = i32,
     out_pad = array<i64: 1, 2, 3, 4>,
-    quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>,
     out_shape = array<i64: -1, -1, -1, -1>,
-    stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x21x26x5xi32>
+    stride = array<i64: 1, 1>} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32>
   return %0 : tensor<2x21x26x5xi32>
 }
 
@@ -71,7 +78,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor<
   // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
   // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
   // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32>
   %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32>
   return %1 : tensor<2x?x?x5xf32>
 }
@@ -98,7 +105,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
 
   // Manipulate the final shape.
   // CHECK-DAG: %[[BIAS:.+]]  = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
-  // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>}
+  // CHECK-DAG: %[[INPUT_ZP:.+]]  = "tosa.const"() <{value = dense<-22> : tensor<1xi8>}
+  // CHECK-DAG: %[[WEIGHT_ZP:.+]]  = "tosa.const"() <{value = dense<42> : tensor<1xi8>}
+  // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
   // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 2, 18, 16, 2, 3, 5>}
   // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]]
   // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]]
@@ -107,7 +116,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]]
   // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2
   // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]]
-  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32>
+  %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: -1, -1, -1, -1>, stride = array<i64: 2, 3>} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x35x47x5xi32>
   return %0 : tensor<2x35x47x5xi32>
 }
 
@@ -135,12 +146,13 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
   // CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]]
   // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array<i64: 1, 1, 1, 1>}
   // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]]
-  %2 =  tosa.transpose_conv2d %arg0, %arg1, %arg2 {
+  %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8>
+  %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8>
+  %2 =  tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {
     acc_type = i32,
     out_pad = array<i64: 2, 0, 0, 1>,
     out_shape = array<i64: 1, -1, -1, 1>,
-    stride = array<i64: 1, 2>,
-    quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>} :
-    (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>) -> tensor<1x19x2x1xi32>
+    stride = array<i64: 1, 2>} :
+    (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
   "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
 }



More information about the Mlir-commits mailing list