[Mlir-commits] [mlir] [mlir][tosa] Add TOSA Max Pool 2D Adaptive (PR #191225)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 9 08:53:57 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Iliyan Georgiev (iliyan-georgiev-arm)
<details>
<summary>Changes</summary>
Implements:
- Operator definition
- Operator verifier
- Validation
- Tests
- Adds NoMemoryEffect to AvgPool2dAdaptive
Signed-off-by: Iliyan Georgiev <Iliyan.Georgiev@arm.com>
Change-Id: I7550cc588ffc0da684605d67db71d989fb51da62
---
Patch is 26.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/191225.diff
10 Files Affected:
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+13)
- (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+37-1)
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+53-3)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp (+9)
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+4-3)
- (modified) mlir/test/Dialect/Tosa/level_check.mlir (+116)
- (modified) mlir/test/Dialect/Tosa/ops.mlir (+47)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+11)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+24)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a575024a6144a..d3e2cd129028e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -65,6 +65,11 @@ profileComplianceMap = {
{{Profile::pro_fp},
{{{fp16T, fp16T}, SpecificationVersion::V_1_0},
{{fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.max_pool2d_adaptive",
+ {{{Profile::pro_int}, {{{i8T, i8T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Profile::pro_fp},
+ {{{fp16T, fp16T}, SpecificationVersion::V_1_1_DRAFT},
+ {{fp32T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.transpose_conv2d",
{{{Profile::pro_int},
{{{i8T, i8T, i32T, i8T, i8T, i32T, i32T}, SpecificationVersion::V_1_0}}},
@@ -657,6 +662,14 @@ extensionComplianceMap = {
{{Extension::fp8e5m2},
{{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
{{Extension::bf16}, {{{bf16T, bf16T}, SpecificationVersion::V_1_0}}}}},
+ {"tosa.max_pool2d_adaptive",
+ {{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e4m3},
+ {{{fp8e4m3T, fp8e4m3T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::fp8e5m2},
+ {{{fp8e5m2T, fp8e5m2T}, SpecificationVersion::V_1_1_DRAFT}}},
+ {{Extension::bf16},
+ {{{bf16T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.rfft2d",
{{{Extension::fft},
{{{fp32T, fp32T, fp32T}, SpecificationVersion::V_1_0}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 5ac91e6b65457..45d1388a28749 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -124,7 +124,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d", [NoMemoryEffect]> {
//===----------------------------------------------------------------------===//
// Operator: avg_pool2d_adaptive
//===----------------------------------------------------------------------===//
-def Tosa_AvgPool2dAdaptiveOp : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive"> {
+def Tosa_AvgPool2dAdaptiveOp
+ : Tosa_InferShapedTypeOp<"avg_pool2d_adaptive", [NoMemoryEffect]> {
let summary = "Performs average pooling on the input with shape operands.";
let description = [{
@@ -524,6 +525,41 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d", [Pure]> {
let hasCustomAssemblyFormat = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: max_pool2d_adaptive
+//===----------------------------------------------------------------------===//
+def Tosa_MaxPool2dAdaptiveOp
+ : Tosa_InferShapedTypeOp<"max_pool2d_adaptive", [Pure]> {
+ let summary = "Performs max pooling on the input.";
+
+ let description = [{
+ This performs a max pooling over the given input tensor. A sliding window of
+ size given by <kernel size> is passed over the input tensor, with the
+ maximum value being placed in the output tensor.
+ Compared to MAX_POOL2D, MAX_POOL2D_ADAPTIVE has the kernel, stride,
+ pad arguments as inputs rather than attributes.
+ }];
+
+ let arguments =
+ (ins Tosa_Tensor4D:$input, Rank2TosaShape:$kernel, Rank2TosaShape:$stride,
+ Rank4TosaShape:$pad,
+
+ DefaultValuedAttr<
+ Tosa_NanPropagationModeAttr,
+ "::mlir::tosa::NanPropagationMode::PROPAGATE">:$nan_mode);
+
+ let results = (outs Tosa_Tensor4D:$output);
+
+ list<Availability> availability =
+ [Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2,
+ Tosa_EXT_BF16]>,
+ ];
+
+ let hasVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 29318023092a1..3bf878304429e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -486,6 +486,15 @@ void MaxPool2dOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}
+ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
+}
+
+void MaxPool2dAdaptiveOp::print(OpAsmPrinter &parser) {
+ printWithNanPropagationHandling(parser, *this);
+}
+
ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}
@@ -1228,9 +1237,8 @@ struct AdaptivePoolingConstShapeValues {
template <typename T>
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp =
- std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
- // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
- ;
+ std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+ std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
template <typename T,
typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
@@ -4085,6 +4093,33 @@ LogicalResult MaxPool2dOp::inferReturnTypeComponents(
inferredReturnShapes);
}
+LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ MaxPool2dAdaptiveOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ ShapeAdaptor inputShape(adaptor.getInput().getType());
+
+ llvm::SmallVector<int64_t> kernelValues;
+ llvm::SmallVector<int64_t> strideValues;
+ llvm::SmallVector<int64_t> padValues;
+ if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
+ kernelValues) &&
+ tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+ strideValues) &&
+ tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
+ return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
+ padValues, inferredReturnShapes);
+ }
+
+ llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
+ if (inputShape.hasRank()) {
+ outputShape[0] = inputShape.getDimSize(0);
+ outputShape[3] = inputShape.getDimSize(3);
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ return success();
+}
+
LogicalResult MaxPool2dOp::verify() {
if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
/* outType = */ getOutput().getType())))
@@ -4096,6 +4131,21 @@ LogicalResult MaxPool2dOp::verify() {
return success();
}
+LogicalResult MaxPool2dAdaptiveOp::verify() {
+ if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
+ /* outType = */ getOutput().getType())))
+ return failure();
+
+ AdaptivePoolingConstShapeValues values;
+ extractAdaptivePoolingConstShapeOperands(*this, values);
+
+ if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
+ values.pad, getInput(), getOutput())))
+ return failure();
+
+ return success();
+}
+
LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
DepthwiseConv2DOp::Adaptor adaptor,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 78bf700597c3c..01c85be4f704f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -88,6 +88,14 @@ ProfileInfoDepot::populateProfileInfo(tosa::AvgPool2dAdaptiveOp op) {
return success();
}
+template <>
+LogicalResult
+ProfileInfoDepot::populateProfileInfo(tosa::MaxPool2dAdaptiveOp op) {
+ addValue(op.getInput());
+ addValue(op.getOutput());
+ return success();
+}
+
template <typename T>
LogicalResult ProfileInfoDepot::populateProfileInfoConv(T op) {
addValue(op.getInput());
@@ -288,6 +296,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Variable)
POPULATE_PROFILE_INFO_CUSTOM(VariableWrite)
POPULATE_PROFILE_INFO_CUSTOM(Dim)
+ POPULATE_PROFILE_INFO_CUSTOM(MaxPool2dAdaptive)
// For the most of tosa operators, all operands are profile/extension related
// and hence are all considered in this profile-based compilance check.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 6169003881487..8c00603d7abb4 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -359,9 +359,8 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
template <typename T>
static constexpr bool IsSupportedAdaptivePoolOp =
- std::is_same_v<T, tosa::AvgPool2dAdaptiveOp>
- // || std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>
- ;
+ std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
+ std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
template <typename T, typename std::enable_if<IsSupportedAdaptivePoolOp<T>,
int>::type = 0>
@@ -817,6 +816,7 @@ LogicalResult TosaValidation::levelCheckRanksAndSizes(Operation *op) {
CHECK_SIZES(MatMul);
CHECK_SIZES(MatmulTBlockScaled);
CHECK_SIZES(MaxPool2d);
+ CHECK_SIZES(MaxPool2dAdaptive);
CHECK_SIZES(RFFT2d);
// Scatter/Gather Operators
CHECK_SIZES(Gather);
@@ -918,6 +918,7 @@ LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
failed(levelCheckConv<tosa::DepthwiseConv2DOp>(op)) ||
failed(levelCheckFFT<tosa::FFT2dOp>(op)) ||
failed(levelCheckPool<tosa::MaxPool2dOp>(op)) ||
+ failed(levelCheckAdaptivePool<tosa::MaxPool2dAdaptiveOp>(op)) ||
failed(levelCheckFFT<tosa::RFFT2dOp>(op)) ||
failed(levelCheckTransposeConv2d(op)) || failed(levelCheckResize(op)) ||
failed(levelCheckConv2DBlockScaled(op))) {
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index b3bdb02c20103..ca4d2dca0e7c9 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -929,6 +929,122 @@ func.func @test_maxpool2d_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x3
// -----
+func.func @test_maxpool2d_adaptive_kernel_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+ %kernel = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+ return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_kernel_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+ %kernel = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: kernel <= MAX_KERNEL (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+ return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_y(%arg0: tensor<1x8194x32x8xf32>) -> tensor<1x2x32x8xf32> {
+ %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[8193, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x8194x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x2x32x8xf32>
+ return %0 : tensor<1x2x32x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_stride_x(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+ %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 8193]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: stride <= MAX_STRIDE (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+ return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_first(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+ // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+ // This is a workaround for the above so that we can level check the padding.
+ %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[8193, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+ return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_second(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+ // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+ // This is a workaround for the above so that we can level check the padding.
+ %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 8193, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+ return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_third(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+ // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+ // This is a workaround for the above so that we can level check the padding.
+ %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 8193, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+ return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
+func.func @test_maxpool2d_adaptive_pad_forth(%arg0: tensor<1x32x8194x8xf32>) -> tensor<1x32x2x8xf32> {
+ // If the source of the kernel passed to max_pool2d_adaptive is a const_shape then pad < kernel check is applied.
+ // This is a workaround for the above so that we can level check the padding.
+ %a = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %b = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %kernel = tosa.add_shape %a, %b : (!tosa.shape<2>, !tosa.shape<2>) -> !tosa.shape<2>
+
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 8193]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ // expected-error@+1 {{'tosa.max_pool2d_adaptive' op failed level check: pad <= MAX_KERNEL (8192), got 8193}}
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad :
+ (tensor<1x32x8194x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x2x8xf32>
+ return %0 : tensor<1x32x2x8xf32>
+}
+
+// -----
+
func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
// expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL (8192), got 16384}}
%0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index e80d3d84a8105..b30e92c4a9621 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -253,6 +253,53 @@ func.func @test_max_pool2d_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8
return %0 : tensor<1x32x32x8xf16>
}
+// CHECK-LABEL: max_pool2d_adaptive_f32
+func.func @test_max_pool2d_adaptive_f32(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+ %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xf32>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xf32>
+ return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_bf16
+func.func @test_max_pool2d_adaptive_bf16(%arg0: tensor<1x32x32x8xbf16>) -> tensor<1x32x32x8xbf16> {
+ %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %pad = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %0 = tosa.max_pool2d_adaptive %arg0, %kernel, %stride, %pad : (tensor<1x32x32x8xbf16>, !tosa.shape<2>, !tosa.shape<2>, !tosa.shape<4>) -> tensor<1x32x32x8xbf16>
+ return %0 : tensor<1x32x32x8xbf16>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d_adaptive_f16
+func.func @test_max_pool2d_adaptive_f16(%arg0: tensor<1x32x32x8xf16>) -> tensor<1x32x32x8xf16> {
+ %kernel = tosa.const_shape {values = dense<[1, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+ %stride = tosa.const_...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/191225
More information about the Mlir-commits
mailing list