[Mlir-commits] [mlir] [mlir][tosa] Add TOSA Avg Pool 2D Adaptive (PR #190200)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 2 08:48:10 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Iliyan Georgiev (iliyan-georgiev-arm)

<details>
<summary>Changes</summary>

Signed-off-by: Deeptanshu Sekhri <deeptanshu.sekhri@<!-- -->arm.com>
Co-authored-by: Iliyan Georgiev <iliyan.georgiev@<!-- -->arm.com>

---

Patch is 79.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/190200.diff


17 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+22) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+12-1) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+44) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+231-89) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+12) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+54) 
- (modified) mlir/test/Dialect/Tosa/availability.mlir (+14) 
- (modified) mlir/test/Dialect/Tosa/dynamic_extension.mlir (+11) 
- (modified) mlir/test/Dialect/Tosa/invalid.mlir (+193) 
- (modified) mlir/test/Dialect/Tosa/invalid_extension.mlir (+9-47) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+60) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+72) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir (+11) 
- (modified) mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir (+12-15) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+83) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+50) 
- (added) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-invalid.mlir (+27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index 7ea0d134941c7..a575024a6144a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -13,6 +13,16 @@ profileComplianceMap = {
        {{{fp16T, fp16T, fp16T, fp16T, fp16T}, SpecificationVersion::V_1_0},
         {{fp16T, fp16T, fp16T, fp32T, fp16T}, SpecificationVersion::V_1_0},
         {{fp32T, fp32T, fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.avg_pool2d_adaptive",
+     {{{Profile::pro_int},
+       {{{i8T, i8T, i8T, i32T, i8T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp},
+       {{{fp16T, fp16T, fp16T, fp16T, fp16T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp16T, fp16T, fp16T, fp32T, fp16T},
+         SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp32T, fp32T, fp32T, fp32T},
+         SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.conv2d",
      {{{Profile::pro_int},
        {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
@@ -517,6 +527,18 @@ extensionComplianceMap = {
          SpecificationVersion::V_1_0}}},
       {{Extension::bf16},
        {{{bf16T, bf16T, bf16T, fp32T, bf16T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.avg_pool2d_adaptive",
+     {{{Extension::int16},
+       {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3},
+       {{{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T, fp8e4m3T},
+         SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e5m2},
+       {{{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T, fp8e5m2T},
+         SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16},
+       {{{bf16T, bf16T, bf16T, fp32T, bf16T},
+         SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.conv2d",
      {{{Extension::int4},
        {{{i8T, i4T, i32T, i8T, i4T, i32T, i32T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 9df17ed89b818..1f05aee3e5eec 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -167,7 +167,7 @@ def Tosa_MatMulOpQuantInfoBuilder : OpBuilder<
   }]>;
 
 // Both the tosa.avg_pool2d and unary ops use the same
-// UnaruOpQuantizationAttr but the avg_pool operator has its own builder as it
+// UnaryOpQuantizationAttr but the avg_pool operator has its own builder as it
 // has additional parameters not part of the unary ops.
 def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
   (ins "::mlir::Type":$outputType, "::mlir::Value":$input,
@@ -178,6 +178,17 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
                                   input, kernel, stride, pad, acc_type);
   }]>;
 
+def Tosa_AvgPool2dAdaptiveOpQuantInfoBuilder
+    : OpBuilder<(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
+                    "::mlir::DenseI64ArrayAttr":$kernel,
+                    "::mlir::DenseI64ArrayAttr":$stride,
+                    "::mlir::DenseI64ArrayAttr":$pad,
+                    "::mlir::TypeAttr":$acc_type),
+                [{
+    buildAvgPool2dAdaptiveOpWithQuantInfo($_builder, $_state, outputType,
+                                          input, kernel, stride, pad, acc_type);
+  }]>;
+
 // This builder is called on single-parameter negate operators that have a scale
 // relationship between their input and output, expressed by the
 // UnaryOpQuantizationAttr.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index cab2bccfc27b3..c1a4a98f03f9c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -121,6 +121,50 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d", [NoMemoryEffect]> {
       "operands attr-dict `:` functional-type(operands, results)";
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: avg_pool2d_adaptive
+//===----------------------------------------------------------------------===//
+def Tosa_AvgPool2dAdaptiveOp : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive", [NoMemoryEffect]> {
+  let summary = "Performs average pooling on the input with shape operands.";
+
+  let description = [{
+    This performs an average pooling over the given input tensor. A sliding
+    window of size given by <kernel size> is passed over the input tensor, with
+    the mean value being placed in the output tensor. When calculating the
+    average, only the number of valid input tensor values, but not padding, are
+    used to calculate the divisor. Unlike avg_pool2d, the kernel, stride and
+    pad values are provided as !tosa.shape operands rather than attributes.
+  }];
+
+  let arguments = (ins Tosa_Tensor4D:$input,
+      Tosa_ScalarIntOrFloatTensor:$input_zp,
+      Tosa_ScalarIntOrFloatTensor:$output_zp, Rank2TosaShape:$kernel,
+      Rank2TosaShape:$stride, Rank4TosaShape:$pad,
+      TypeAttrOf<Tosa_AccType>:$acc_type);
+
+  let results = (outs Tosa_Tensor4D:$output);
+
+  list<Availability> availability =
+      [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+       Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2,
+                  Tosa_EXT_BF16]>,
+  ];
+
+  let builders = [Tosa_AvgPool2dAdaptiveOpQuantInfoBuilder];
+
+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getInputZeroPoint();
+    FailureOr<int64_t> getOutputZeroPoint();
+    LogicalResult verifyInputZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
+
+  let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: conv2d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6072aecdf347b..57353b683de78 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 
 #include <numeric>
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -1077,128 +1078,192 @@ LogicalResult tosa::ArgMaxOp::verify() {
   return success();
 }
 
-template <typename T>
-static LogicalResult verifyPoolingOp(T op) {
-  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
-  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
-    return op.emitOpError("expect all kernel values to be >= 1, got ")
+static LogicalResult verifyPoolingOpImpl(Operation *op,
+                                         ArrayRef<int64_t> kernel,
+                                         ArrayRef<int64_t> strides,
+                                         ArrayRef<int64_t> padding, Value input,
+                                         Value output) {
+  const bool hasKernel = kernel.size() > 0;
+  const bool hasStrides = strides.size() > 0;
+  const bool hasPad = padding.size() > 0;
+
+  if (hasKernel && llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
+    return op->emitOpError("expect all kernel values to be >= 1, got ")
            << kernel;
 
-  const llvm::ArrayRef<int64_t> strides = op.getStride();
-  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
-    return op.emitOpError("expect all stride values to be >= 1, got ")
+  if (hasStrides && llvm::any_of(strides, [](int64_t s) { return s < 1; }))
+    return op->emitOpError("expect all stride values to be >= 1, got ")
            << strides;
 
-  const llvm::ArrayRef<int64_t> padding = op.getPad();
-  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
-    return op.emitOpError("expect all padding values to be >= 0, got ")
+  if (hasPad && llvm::any_of(padding, [](int64_t p) { return p < 0; }))
+    return op->emitOpError("expect all padding values to be >= 0, got ")
            << padding;
 
-  // Padding must be less than kernel size to avoid a divide-by-zero
-  const int64_t kernelX = kernel[1];
-  const int64_t padLeft = padding[2];
-  const int64_t padRight = padding[3];
-  if (padRight >= kernelX || padLeft >= kernelX)
-    return op.emitOpError("expected left/right padding to be less than the "
-                          "width of the kernel, got pad_left=")
-           << padLeft << ", pad_right=" << padRight << ", kernel_x=" << kernelX;
-
-  const int64_t kernelY = kernel[0];
-  const int64_t padTop = padding[0];
-  const int64_t padBottom = padding[1];
-  if (padTop >= kernelY || padBottom >= kernelY)
-    return op.emitOpError("expected top/bottom padding to be less than the "
-                          "height of the kernel, got pad_top=")
-           << padTop << ", pad_bottom=" << padBottom
-           << ", kernel_y=" << kernelY;
-
-  const auto inputType =
-      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
-  const auto outputType =
-      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
+  if (hasKernel && hasPad) {
+    // Padding must be less than kernel size to avoid a divide-by-zero
+    const int64_t kernelX = kernel[1];
+    const int64_t padLeft = padding[2];
+    const int64_t padRight = padding[3];
+    if (padRight >= kernelX || padLeft >= kernelX)
+      return op->emitOpError("expected left/right padding to be less than the "
+                             "width of the kernel, got pad_left=")
+             << padLeft << ", pad_right=" << padRight
+             << ", kernel_x=" << kernelX;
+
+    const int64_t kernelY = kernel[0];
+    const int64_t padTop = padding[0];
+    const int64_t padBottom = padding[1];
+    if (padTop >= kernelY || padBottom >= kernelY)
+      return op->emitOpError("expected top/bottom padding to be less than the "
+                             "height of the kernel, got pad_top=")
+             << padTop << ", pad_bottom=" << padBottom
+             << ", kernel_y=" << kernelY;
+  }
+
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
+  const auto outputType = llvm::dyn_cast<RankedTensorType>(output.getType());
   if (!inputType || !outputType)
     return success();
 
-  const auto verifyOutputSize =
-      [&op](const int64_t inputSize, const int64_t outputSize,
-            const int64_t kernelSize, const int64_t strideSize,
-            const int64_t padBefore, const int64_t padAfter,
-            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
-            const llvm::StringRef padBeforeName,
-            const llvm::StringRef padAfterName) -> LogicalResult {
-    if (ShapedType::isDynamic(inputSize))
-      return success();
-
-    const std::optional<int64_t> calculatedOutSizeMinusOne =
-        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
-    if (!calculatedOutSizeMinusOne.has_value())
-      return op.emitOpError("expected input_")
-             << dimName << " + pad_" << padBeforeName << " + pad_"
-             << padAfterName << " - kernel_" << dimAxis
-             << " to be wholly divisible by stride_" << dimAxis << ", got ("
-             << inputSize << " + " << padBefore << " + " << padAfter << " - "
-             << kernelSize << ") / " << strideSize;
-
-    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
-    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
-      return op.emitOpError("calculated output ")
-             << dimName << " did not match expected: " << "calculated="
-             << calculatedOutSize << ", expected=" << outputSize;
-
-    return success();
-  };
+  if (hasKernel && hasStrides && hasPad) {
+    const auto verifyOutputSize =
+        [op](const int64_t inputSize, const int64_t outputSize,
+             const int64_t kernelSize, const int64_t strideSize,
+             const int64_t padBefore, const int64_t padAfter,
+             const llvm::StringRef dimName, const llvm::StringRef dimAxis,
+             const llvm::StringRef padBeforeName,
+             const llvm::StringRef padAfterName) -> LogicalResult {
+      if (ShapedType::isDynamic(inputSize))
+        return success();
+
+      const std::optional<int64_t> calculatedOutSizeMinusOne =
+          idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
+      if (!calculatedOutSizeMinusOne.has_value())
+        return op->emitOpError("expected input_")
+               << dimName << " + pad_" << padBeforeName << " + pad_"
+               << padAfterName << " - kernel_" << dimAxis
+               << " to be wholly divisible by stride_" << dimAxis << ", got ("
+               << inputSize << " + " << padBefore << " + " << padAfter << " - "
+               << kernelSize << ") / " << strideSize;
+
+      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
+      if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
+        return op->emitOpError("calculated output ")
+               << dimName << " did not match expected: " << "calculated="
+               << calculatedOutSize << ", expected=" << outputSize;
 
-  if (failed(verifyOutputSize(inputType.getDimSize(1), outputType.getDimSize(1),
-                              kernel[0], strides[0], padding[0], padding[1],
-                              "height", "y", "top", "bottom")))
-    return failure();
+      return success();
+    };
 
-  if (failed(verifyOutputSize(inputType.getDimSize(2), outputType.getDimSize(2),
-                              kernel[1], strides[1], padding[2], padding[3],
-                              "width", "x", "left", "right")))
-    return failure();
+    if (failed(verifyOutputSize(inputType.getDimSize(1),
+                                outputType.getDimSize(1), kernel[0], strides[0],
+                                padding[0], padding[1], "height", "y", "top",
+                                "bottom")))
+      return failure();
 
+    if (failed(verifyOutputSize(
+            inputType.getDimSize(2), outputType.getDimSize(2), kernel[1],
+            strides[1], padding[2], padding[3], "width", "x", "left", "right")))
+      return failure();
+  }
   return success();
 }
 
-LogicalResult tosa::AvgPool2dOp::verify() {
-  if (failed(verifyPoolingOp(*this)))
-    return failure();
+template <typename T>
+static LogicalResult verifyPoolingOp(T op) {
+  return verifyPoolingOpImpl(op.getOperation(), op.getKernel(), op.getStride(),
+                             op.getPad(), op.getInput(), op.getOutput());
+}
 
-  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
-  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
-  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
-  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());
+template <typename T>
+static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op) {
+  const Type inputETy = getStorageElementTypeOrSelf(op.getInput().getType());
+  const Type resultETy = getStorageElementTypeOrSelf(op.getOutput().getType());
+  const Type inputZpETy =
+      getStorageElementTypeOrSelf(op.getInputZp().getType());
+  const Type outputZpETy =
+      getStorageElementTypeOrSelf(op.getOutputZp().getType());
 
-  auto accType = getAccType();
+  auto accType = op.getAccType();
   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
-    return emitOpError("accumulator type for integer tensor is not i32");
+    return op.emitOpError("accumulator type for integer tensor is not i32");
 
   if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
-    return emitOpError("accumulator type for f16 tensor is not f16/f32");
+    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
 
   if (inputETy.isBF16() && !accType.isF32())
-    return emitOpError("accumulator type for bf16 tensor is not f32");
+    return op.emitOpError("accumulator type for bf16 tensor is not f32");
 
   if (inputETy.isF32() && !accType.isF32())
-    return emitOpError("accumulator type for f32 tensor is not f32");
+    return op.emitOpError("accumulator type for f32 tensor is not f32");
 
   if (inputETy != inputZpETy)
-    return emitOpError("expect both input and its zero point are the same "
-                       "element type, got ")
+    return op.emitOpError("expect both input and its zero point are the same "
+                          "element type, got ")
            << inputETy << " and " << inputZpETy;
 
   if (resultETy != outputZpETy)
-    return emitOpError("expect both output and its zero point are the same "
-                       "element type, got ")
+    return op.emitOpError("expect both output and its zero point are the same "
+                          "element type, got ")
            << resultETy << " and " << outputZpETy;
 
-  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
-  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
+  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
+  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
     return failure();
 
-  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
-  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
+  if (succeeded(maybeOZp) && op.verifyOutputZeroPoint(*maybeOZp).failed())
+    return failure();
+
+  return success();
+}
+
+namespace {
+struct AdaptivePoolingConstShapeValues {
+  llvm::SmallVector<int64_t> kernel;
+  llvm::SmallVector<int64_t> stride;
+  llvm::SmallVector<int64_t> pad;
+};
+} // namespace
+
+/// Adaptive pooling op types accepted by the helper below; extend as needed.
+template <typename T>
+static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =
+    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>;
+
+/// Reads the constant kernel/stride/pad values of \p op into \p values.
+template <typename T>
+static void extractAdaptivePoolingConstShapeOperands(
+    T op, AdaptivePoolingConstShapeValues &values) {
+  static_assert(IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
+                "unsupported adaptive pooling op");
+  tosa::getConstShapeValues(op.getKernel().getDefiningOp(), values.kernel);
+  tosa::getConstShapeValues(op.getStride().getDefiningOp(), values.stride);
+  tosa::getConstShapeValues(op.getPad().getDefiningOp(), values.pad);
+}
+
+LogicalResult tosa::AvgPool2dOp::verify() {
+  if (failed(verifyPoolingOp(*this)))
+    return failure();
+  if (failed(verifyAvgPoolCommonTypeAndZpChecks(*this)))
+    return failure();
+  return success();
+}
+
+LogicalResult tosa::AvgPool2dAdaptiveOp::verify() {
+  AdaptivePoolingConstShapeValues values;
+  extractAdaptivePoolingConstShapeOperands(*this, values);
+
+  // kernel/stride/pad need not be constant; when they are not, their values
+  // cannot be checked here. extractAdaptivePoolingConstShapeOperands is
+  // expected to leave the list empty for each operand not produced by a
+  // constant shape op, and verifyPoolingOpImpl skips checks on absent values.
+
+  if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
+                                 values.pad, getInput(), getOutput())))
+    return failure();
+
+  if (failed(verifyAvgPoolCommonTypeAndZpChecks(*this)))
     return failure();
 
   return success();
@@ -1394,6 +1459,52 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.types.push_back(outputType);
 }
 
+/// This builder mirrors avg_pool2d quant-info handling and materializes
+/// kernel/stride/pad as const_shape operands for avg_pool2d_adaptive.
+static void buildAvgPool2dAdaptiveOpWithQuantInfo(
+    OpBuilder &builder, OperationState &result, Type outputType, Value input,
+    DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad,
+    TypeAttr accType) {
+  const Location loc{result.location};
+  int64_t inputZp{0};
+  int64_t outputZp{0};
+
+  if (auto quantAttr =
+          buildUnaryOpQuantizationAttr(builder, input, outputType)) {
+    inputZp = quantAttr.getInputZp();
+    outputZp = quantAttr.getOutputZp();
+  }
+  const std::optional<Value> inputZpOp =
+      createZeroPointTensor(builder, loc, input.getType(), inputZp);
+  if (!inputZpOp) {
+    (void)emitError(loc,
+                    "Failed to create input zero point tensor for quantized "
+                  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/190200


More information about the Mlir-commits mailing list