[Mlir-commits] [mlir] [mlir][tosa] Switch zero point of avgpool2d to input variable type (PR #128983)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 26 17:45:55 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Tai Ly (Tai78641)

<details>
<summary>Changes</summary>

This commit changes the zero-point attributes of the TOSA AvgPool2d operator into input operands, to align with the TOSA 1.0 specification.

Change-Id: Ieee6ba824327913bc8462cbcb7a74c0b6dd53d21

---

Patch is 44.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/128983.diff


15 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+6-6) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+11-3) 
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+15-6) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+64-47) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+2) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+2-2) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+12-4) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir (+5-5) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+3-1) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+50-6) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+24-24) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+18-6) 
- (modified) mlir/test/Dialect/Tosa/profile_all_unsupported.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+2-2) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+16-4) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 2617a902c3a0d..941a7fe2b23ad 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -5,9 +5,9 @@ profileComplianceMap = {
      {{{Profile::pro_int}, {{i8T, i32T}}},
       {{Profile::pro_fp}, {{fp16T, i32T}, {fp32T, i32T}}}}},
     {"tosa.avg_pool2d",
-     {{{Profile::pro_int}, {{i8T, i32T, i8T}}},
+     {{{Profile::pro_int}, {{i8T, i8T, i8T, i32T, i8T}}},
       {{Profile::pro_fp},
-       {{fp16T, fp16T, fp16T}, {fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T}}}}},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T}, {fp16T, fp16T, fp16T, fp32T, fp16T}, {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.conv2d",
      {{{Profile::pro_int}, {{i8T, i8T, i32T, i32T, i32T}}},
       {{Profile::pro_fp},
@@ -243,10 +243,10 @@ extensionComplianceMap = {
       {{Extension::fp8e5m2}, {{fp8e5m2T, i32T}}},
       {{Extension::bf16}, {{bf16T, i32T}}}}},
     {"tosa.avg_pool2d",
-     {{{Extension::int16}, {{i16T, i32T, i16T}}},
-      {{Extension::fp8e4m3}, {{fp8e4m3T, fp16T, fp8e4m3T}}},
-      {{Extension::fp8e5m2}, {{fp8e5m2T, fp16T, fp8e5m2T}}},
-      {{Extension::bf16}, {{bf16T, fp32T, bf16T}}}}},
+     {{{Extension::int16}, {{i16T, i16T, i16T, i32T, i16T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, fp32T, bf16T}}}}},
     {"tosa.conv2d",
      {{{Extension::int4}, {{i8T, i4T, i32T, i32T, i32T}}},
       {{Extension::int16}, {{i16T, i8T, i48T, i48T, i48T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..27a4bc57d1964 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -79,12 +79,12 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$output_zp,
     Tosa_IntArrayAttr2:$kernel,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
-    TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<I32Attr>:$input_zp,
-    OptionalAttr<I32Attr>:$output_zp
+    TypeAttrOf<Tosa_AccType>:$acc_type
   );
 
   let results = (outs
@@ -97,6 +97,14 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
   ];
 
   let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
+
+  let extraClassDeclaration = [{
+    LogicalResult getInputZeroPoint(int64_t &zp);
+    LogicalResult getOutputZeroPoint(int64_t &zp);
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 006e35806d64f..9c8d4f75c6e37 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -804,6 +804,15 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
       return failure();
     SmallVector<Value> dynamicDims = *dynamicDimsOr;
 
+    int64_t inputZpVal;
+    int64_t outputZpVal;
+    if (op.getInputZeroPoint(inputZpVal).failed() ||
+        op.getOutputZeroPoint(outputZpVal).failed()) {
+      (void)rewriter.notifyMatchFailure(
+          op, "zero points could not be statically determined");
+      return failure();
+    }
+
     // Apply padding as necessary.
     llvm::SmallVector<int64_t> pad;
     pad.resize(2, 0);
@@ -923,9 +932,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply an offset
             // for the input zp value.
-            if (op.getInputZp()) {
-              auto inputZp =
-                  rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+            if (inputZpVal != 0) {
+              auto inputZp = rewriter.create<arith::ConstantOp>(
+                  loc, b.getIntegerAttr(accETy, inputZpVal));
               Value offset =
                   rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
               poolVal =
@@ -977,9 +986,9 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
 
             // If we have quantization information we need to apply output
             // zeropoint.
-            if (op.getOutputZp()) {
-              auto outputZp =
-                  rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
+            if (outputZpVal != 0) {
+              auto outputZp = rewriter.create<arith::ConstantOp>(
+                  loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
               scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                            .getResult();
             }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63e183538c257 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -455,18 +455,10 @@ LogicalResult tosa::ArgMaxOp::verify() {
 }
 
 LogicalResult tosa::AvgPool2dOp::verify() {
-  auto inputType = llvm::cast<ShapedType>(getInput().getType());
-
-  auto inputETy = inputType.getElementType();
-  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
-
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
-    inputETy = quantType.getStorageType();
-
-  if (auto quantType =
-          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
-    resultETy = quantType.getStorageType();
+  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
+  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
+  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
+  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
 
   auto accType = getAccType();
   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
@@ -481,6 +473,28 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (inputETy.isF32() && !accType.isF32())
     return emitOpError("accumulator type for f32 tensor is not f32");
 
+  if (inputETy != inputZpETy)
+    return emitOpError("expect both input and its zero point are the same "
+                       "element type, got ")
+           << inputETy << " and " << inputZpETy;
+
+  if (resultETy != outputZpETy)
+    return emitOpError("expect both output and its zero point are the same "
+                       "element type, got ")
+           << resultETy << " and " << outputZpETy;
+
+  int64_t inputZpVal;
+  if (getInputZeroPoint(inputZpVal).succeeded() &&
+      verifyInputZeroPoint(inputZpVal).failed())
+    return emitOpError(
+        "input zero point must be zero for non-int8 integer types");
+
+  int64_t outputZpVal;
+  if (getOutputZeroPoint(outputZpVal).succeeded() &&
+      verifyOutputZeroPoint(outputZpVal).failed())
+    return emitOpError(
+        "output zero point must be zero for non-int8 integer types");
+
   if ((inputETy.isF32() && resultETy.isF32()) ||
       (inputETy.isF16() && resultETy.isF16()) ||
       (inputETy.isBF16() && resultETy.isBF16()) ||
@@ -629,27 +643,37 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
 }
 
 /// Both the tosa.avg_pool2d and unary ops use the same
-/// UnaruOpQuantizationAttr but avg_pool operator has its own builder as it
+/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
 /// has additional parameters not part of the unary ops.
 static void
 buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                               Type outputType, Value input,
                               DenseArrayAttr kernel, DenseArrayAttr stride,
                               DenseArrayAttr pad, TypeAttr accType) {
-  result.addOperands(input);
+  const Location loc{result.location};
+  int64_t inputZp{0};
+  int64_t outputZp{0};
+
+  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
+  if (quantAttr) {
+    inputZp = quantAttr.getInputZp();
+    outputZp = quantAttr.getOutputZp();
+  }
+  const std::optional<Value> inputZpOp =
+      createZeroPointTensor(builder, loc, input.getType(), inputZp);
+  assert(
+      inputZpOp.has_value() &&
+      "Failed to create input zero point tensor for quantized AVG_POOL2D op");
+  const std::optional<Value> outputZpOp =
+      createZeroPointTensor(builder, loc, outputType, outputZp);
+  assert(
+      outputZpOp.has_value() &&
+      "Failed to create output zero point tensor for quantized AVG_POOL2D op");
+  result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
   result.addAttribute("kernel", kernel);
   result.addAttribute("stride", stride);
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
-  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr) {
-    result.addAttribute("input_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getInputZp())));
-    result.addAttribute("output_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getOutputZp())));
-  }
   result.types.push_back(outputType);
 }
 
@@ -1425,13 +1449,6 @@ static LogicalResult getZeroPoint(T op, Value val, int64_t &zp) {
 
 template <typename T>
 static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
-  // TODO clean it up when the entire zero point (attribute -> input tensor
-  // type) change is done. Remaining Matmul, Rescale, Negate, and AvgPool2D.
-  if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
-                !std::is_same_v<T, DepthwiseConv2DOp> &&
-                !std::is_same_v<T, TransposeConv2DOp>)
-    return failure();
-
   Type zpElemType = getElementTypeOrSelf(val);
 
   if (!zpElemType.isIntOrFloat())
@@ -1446,24 +1463,24 @@ static LogicalResult verifyZeroPoint(T op, Value val, int64_t &zp) {
   return success();
 }
 
-#define ZERO_POINT_HELPER(OP)                                                  \
-  LogicalResult tosa::OP::getInputZeroPoint(int64_t &zp) {                     \
-    return getZeroPoint(*this, getInputZp(), zp);                              \
+#define ZERO_POINT_HELPER(OP, OPERAND_NAME)                                    \
+  LogicalResult tosa::OP::get##OPERAND_NAME##ZeroPoint(int64_t &zp) {          \
+    return getZeroPoint(*this, get##OPERAND_NAME##Zp(), zp);                   \
   }                                                                            \
-  LogicalResult tosa::OP::getWeightZeroPoint(int64_t &zp) {                    \
-    return getZeroPoint(*this, getWeightZp(), zp);                             \
-  }                                                                            \
-  LogicalResult tosa::OP::verifyInputZeroPoint(int64_t zp) {                   \
-    return verifyZeroPoint(*this, getInputZp(), zp);                           \
-  }                                                                            \
-  LogicalResult tosa::OP::verifyWeightZeroPoint(int64_t zp) {                  \
-    return verifyZeroPoint(*this, getWeightZp(), zp);                          \
-  }
-
-ZERO_POINT_HELPER(Conv2DOp)
-ZERO_POINT_HELPER(Conv3DOp)
-ZERO_POINT_HELPER(DepthwiseConv2DOp)
-ZERO_POINT_HELPER(TransposeConv2DOp)
+  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
+    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp);                \
+  }
+
+ZERO_POINT_HELPER(Conv2DOp, Input)
+ZERO_POINT_HELPER(Conv2DOp, Weight)
+ZERO_POINT_HELPER(Conv3DOp, Input)
+ZERO_POINT_HELPER(Conv3DOp, Weight)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Input)
+ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight)
+ZERO_POINT_HELPER(TransposeConv2DOp, Input)
+ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
+ZERO_POINT_HELPER(AvgPool2dOp, Input)
+ZERO_POINT_HELPER(AvgPool2dOp, Output)
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 1d8aaa65c2976..345616c9563b5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -58,6 +58,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::ConcatOp op) {
 template <>
 void ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dOp op) {
   addValue(op.getInput());
+  addValue(op.getInputZp());
+  addValue(op.getOutputZp());
   addType(op.getAccType());
   addValue(op.getOutput());
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..3bd1049f1a0ce 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics
 
 // CHECK-LABEL: @avg_pool2d_with_unsupported_quant_type
-func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
+func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
  // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
-  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
+  %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
   return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 332b706871547..386855420f48e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -290,7 +290,9 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
   // CHECK:   %[[FLT:.+]] = arith.sitofp %[[CAST]]
   // CHECK:   %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
   // CHECK:   linalg.yield %[[DIV]]
-  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> tensor<1x5x33x62xf32>
+  %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x33x62xf32>
   return %0 : tensor<1x5x33x62xf32>
 }
 
@@ -375,7 +377,9 @@ func.func @avg_pool_f16_f32acc(%arg0: tensor<1x6x34x62xf16>) -> (tensor<1x5x33x6
   // CHECK:   %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
   // CHECK:   %[[TRUNC:.+]] = arith.truncf %[[DIV]]
   // CHECK:   linalg.yield %[[TRUNC]]
-  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>) -> tensor<1x5x33x62xf16>
+  %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x33x62xf16>
   return %0 : tensor<1x5x33x62xf16>
 }
 
@@ -416,7 +420,9 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
   // CHECK: %[[CLAMP:.+]] = arith.minsi %[[CMAX]], %[[LOW]]
   // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
   // CHECK: linalg.yield %[[TRUNC]]
-  %0 = tosa.avg_pool2d %arg0 {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> tensor<1x5x33x62xi8>
+  %input_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x33x62xi8>
   return %0 : tensor<1x5x33x62xi8>
 }
 
@@ -439,7 +445,9 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
   // CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
   // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
   // CHECK: %[[GENERIC:.+]] = linalg.generic
-  %0 = tosa.avg_pool2d %arg0 {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> tensor<?x5x33x62xf32>
+  %input_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %output_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.avg_pool2d %arg0, %input_zp, %output_zp {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x33x62xf32>
   return %0 : tensor<?x5x33x62xf32>
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
index 73da2810abe04..ecd5c792e08b6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-pipeline.mlir
@@ -23,18 +23,18 @@ func.func @tensor_with_unknown_rank(%arg0: tensor<*xi8>) -> tensor<*xi8> {
 // -----
 
 // check that tosa verify kick in
-func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x7x7x9xf32> {
  // expected-error@+1 {{'tosa.avg_pool2d' op operand #0 must be 4-d tosa-conformant tensor, but got 'tensor<1x0x?x9xf32>'}}
-    %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
-      : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
+    %0 = "tosa.avg_pool2d"(%arg0, %arg1, %arg2) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+      : (tensor<1x0x?x9xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x7x7x9xf32>
     return %0 : tensor<1x7x7x9xf32>
 }
 
 // -----
 
 // check that --tosa-to-linalg kick in
-func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
+func.func @avg_pool2d_with_unsupported_quant_type(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
  // expected-error@+1 {{failed to legalize operation 'tosa.avg_pool2d'}}
-  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/128983


More information about the Mlir-commits mailing list