[Mlir-commits] [mlir] [mlir][tosa] Add TOSA Max Pool 2D Adaptive (PR #191225)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 9 08:53:57 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Iliyan Georgiev (iliyan-georgiev-arm)

<details>
<summary>Changes</summary>

Implements:
- Operator definition
- Operator verifier
- Validation
- Tests
- Adds NoMemoryEffect to AvgPool2dAdaptive

Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev@<!-- -->arm.com>
Change-Id: I7550cc588ffc0da684605d67db71d989fb51da62

---

Patch is 26.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/191225.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+13) 
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+37-1) 
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+53-3) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+9) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4-3) 
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+116) 
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+47) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+11) 
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+24) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a575024a6144a..d3e2cd129028e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -65,6 +65,11 @@ profileComplianceMap = {
       {{Profile::pro_fp},
        {{{fp16T, fp16T}, SpecificationVersion::V_1_0},
         {{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.max_pool2d_adaptive",
+     {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Profile::pro_fp},
+       {{{fp16T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+        {{fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.transpose_conv2d",
      {{{Profile::pro_int},
        {{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
@@ -657,6 +662,14 @@ extensionComplianceMap = {
       {{Extension::fp8e5m2},
        {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
       {{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+    {"tosa.max_pool2d_adaptive",
+     {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e4m3},
+       {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::fp8e5m2},
+       {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+      {{Extension::bf16},
+       {{{bf16T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}}}},
     {"tosa.rfft2d",
      {{{Extension::fft},
        {{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5ac91e6b65457..45d1388a28749 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -124,7 +124,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d", [NoMemoryEffect]> {
 //===----------------------------------------------------------------------===//
 // Operator: avg_pool2d_adaptive
 //===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dAdaptiveOp : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive"> {
+def Tosa_AvgPool2dAdaptiveOp
+    : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive", [NoMemoryEffect]> {
   let summary = "Performs average pooling on the input with shape operands.";
 
   let description = [{
@@ -524,6 +525,41 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d", [Pure]> {
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: max_pool2d_adaptive
+//===----------------------------------------------------------------------===//
+def Tosa_MaxPool2dAdaptiveOp
+    : Tosa_InferShapedTypeOp<"max_pool2d_adaptive", [Pure]> {
+  let summary = "Performs max pooling on the input.";
+
+  let description = [{
+    This performs max pooling over the given input tensor. A sliding window
+    whose size is given by the kernel operand is passed over the input
+    tensor, and the maximum value in each window is placed in the output
+    tensor. Unlike MAX_POOL2D, MAX_POOL2D_ADAPTIVE takes the kernel, stride
+    and pad values as shape operands rather than as attributes.
+  }];
+
+  let arguments =
+      (ins Tosa_Tensor4D:$input, Rank2TosaShape:$kernel, Rank2TosaShape:$stride,
+          Rank4TosaShape:$pad,
+
+          DefaultValuedAttr<
+              Tosa_NanPropagationModeAttr,
+              "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode);
+
+  let results = (outs Tosa_Tensor4D:$output);
+
+  list<Availability> availability =
+      [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+       Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2,
+                  Tosa_EXT_BF16]>,
+  ];
+
+  let hasVerifier = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 29318023092a1..3bf878304429e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -486,6 +486,15 @@ void MaxPool2dOp::print(OpAsmPrinter &parser) {
   printWithNanPropagationHandling(parser, *this);
 }
 
+ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaxPool2dAdaptiveOp::print(OpAsmPrinter &parser) {
+  printWithNanPropagationHandling(parser, *this);
+}
+
 ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
   return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
 }
@@ -1228,9 +1237,8 @@ struct AdaptivePoolingConstShapeValues {
 
 template <typename T>
 static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =
-    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
-    // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
-    ;
+    std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+    std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
 
 template <typename T,
           typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
@@ -4085,6 +4093,33 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
                                  inferredReturnShapes);
 }
 
+LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    MaxPool2dAdaptiveOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape(adaptor.getInput().getType());
+
+  llvm::SmallVector<int64_t> kernelValues;
+  llvm::SmallVector<int64_t> strideValues;
+  llvm::SmallVector<int64_t> padValues;
+  if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
+                                kernelValues) &&
+      tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                strideValues) &&
+      tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
+    return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
+                                   padValues, inferredReturnShapes);
+  }
+
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
+  if (inputShape.hasRank()) {
+    outputShape[0] = inputShape.getDimSize(0);
+    outputShape[3] = inputShape.getDimSize(3);
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
 LogicalResult MaxPool2dOp::verify() {
   if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
                                     /* outType = */ getOutput().getType())))
@@ -4096,6 +4131,21 @@ LogicalResult MaxPool2dOp::verify() {
   return success();
 }
 
+LogicalResult MaxPool2dAdaptiveOp::verify() {
+  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
+                                    /* outType = */ getOutput().getType())))
+    return failure();
+
+  AdaptivePoolingConstShapeValues values;
+  extractAdaptivePoolingConstShapeOperands(*this, values);
+
+  if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
+                                 values.pad, getInput(), getOutput())))
+    return failure();
+
+  return success();
+}
+
 LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     DepthwiseConv2DOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 78bf700597c3c..01c85be4f704f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -88,6 +88,14 @@ ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
   return success();
 }
 
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
+  addValue(op.getInput());
+  addValue(op.getOutput());
+  return success();
+}
+
 template <typename T>
 LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
   addValue(op.getInput());
@@ -288,6 +296,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Variable)
   POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
   POPULATE_PROFILE_INFO_CUSTOM(Dim)
+  POPULATE_PROFILE_INFO_CUSTOM(MaxPool2dAdaptive)
 
   // For the most of tosa operators, all operands are profile/extension related
   // and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6169003881487..8c00603d7abb4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -359,9 +359,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 
   template <typename T>
   static constexpr bool IsSupportedAdaptivePoolOp =
-      std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
-      // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
-      ;
+      std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+      std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
 
   template <typename T, typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
                                                 int>::type = 0>
@@ -817,6 +816,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
   CHECK_SIZES(MatMul);
   CHECK_SIZES(MatmulTBlockScaled);
   CHECK_SIZES(MaxPool2d);
+  CHECK_SIZES(MaxPool2dAdaptive);
   CHECK_SIZES(RFFT2d);
   // Scatter/Gather Operators
   CHECK_SIZES(Gather);
@@ -918,6 +918,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
       failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
       failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
       failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
+      failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
       failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
       failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
       failed(levelCheckConv2DBlockScaled(op))) {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b3bdb02c20103..ca4d2dca0e7c9 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -929,6 +929,122 @@ func.func @test_maxpool2d_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x3
 
 // -----
 
+func.func @test_maxpool2d_adaptive_kernel_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+  return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_kernel_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+  return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_first(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // When the kernel operand comes from a tosa.const_shape, the verifier also checks pad < kernel.
+  // Build the kernel with tosa.add_shape instead so that check is skipped and the pad level check is reached.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+  
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[8193, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_second(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // When the kernel operand comes from a tosa.const_shape, the verifier also checks pad < kernel.
+  // Build the kernel with tosa.add_shape instead so that check is skipped and the pad level check is reached.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 8193, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_third(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // When the kernel operand comes from a tosa.const_shape, the verifier also checks pad < kernel.
+  // Build the kernel with tosa.add_shape instead so that check is skipped and the pad level check is reached.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 8193, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_forth(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+  // When the kernel operand comes from a tosa.const_shape, the verifier also checks pad < kernel.
+  // Build the kernel with tosa.add_shape instead so that check is skipped and the pad level check is reached.
+  %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 8193]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  // expected-error at +1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+         (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+  return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
 func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
   // expected-error at +1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL (8192), got 16384}}
   %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e80d3d84a8105..b30e92c4a9621 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -253,6 +253,53 @@ func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8
   return %0 : tensor<1x32x32x8xf16>
 }
 
+// CHECK-LABEL: max_pool2d_adaptive_f32
+func.func @test_max_pool2d_adaptive_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_bf16
+func.func @test_max_pool2d_adaptive_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xbf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xbf16>
+  return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_f16
+func.func @test_max_pool2d_adaptive_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+  %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %stride = tosa.const_...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/191225


More information about the Mlir-commits mailing list