[Mlir-commits] [mlir] [mlir][tosa] Add verifiers for FFT2d and RFFT2d (PR #129273)
Luke Hutton
llvmlistbot at llvm.org
Fri Feb 28 09:15:48 PST 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/129273
Adds verifier checks to FFT2d and RFFT2d for element types and input/output shapes, including requiring static height and width to be powers of two.
>From c6e810d1c68cc6f95f713b70f94531ecfb2db016 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 26 Feb 2025 15:49:17 +0000
Subject: [PATCH] [mlir][tosa] Add verifiers for FFT2d and RFFT2d
Adds verifier checks to FFT2d and RFFT2d for element types and input/output shapes, including requiring static height and width to be powers of two.
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: Ib40928027f5b9d75306aa662c4627e3263db7de7
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 13 ++-
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 80 ++++++++++++++++++-
.../TosaToLinalg/tosa-to-linalg-invalid.mlir | 9 ---
.../TosaToLinalg/tosa-to-linalg.mlir | 24 +++---
mlir/test/Dialect/Tosa/invalid.mlir | 80 +++++++++++++++++++
mlir/test/Dialect/Tosa/level_check.mlir | 36 ++++-----
6 files changed, 200 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..9e14daf0d014c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
//===----------------------------------------------------------------------===//
// Operator: fft2d
//===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+ SameOperandsAndResultElementType,
+ SameOperandsAndResultShape,
+ ResultsAreFloatLike]> {
let summary = "Performs FFT2D operation on the input.";
let description = [{
@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
$input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
//===----------------------------------------------------------------------===//
// Operator: rfft2d
//===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+ SameOperandsAndResultElementType,
+ ResultsAreFloatLike]> {
let summary = "Performs RFFT2D operation on the input.";
let description = [{
@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
let assemblyFormat = [{
$input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63afaed22d7ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
int64_t inWidth = inputShape.getDimSize(2);
// Note that we can support this calculation symbolically
- // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+ // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
if (inWidth != ShapedType::kDynamic)
outputShape[2] = inWidth / 2 + 1;
@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
return success();
}
+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+ const llvm::StringRef dimName) {
+ const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
+ if (!isPowerOfTwo)
+ return op->emitOpError("expected ")
+ << dimName << " to be a power of two, got " << dimSize;
+
+ return success();
+}
+
+/// Verifies tosa.rfft2d beyond what its ODS traits check:
+///  - the real and imaginary outputs have compatible shapes;
+///  - static input height and width are powers of two;
+///  - input/output batch and height dimensions are compatible;
+///  - a static output width equals input_width / 2 + 1.
+/// Checks that need ranked or static shape information are skipped otherwise.
+LogicalResult tosa::RFFT2dOp::verify() {
+  const auto outputTypes = getResultTypes();
+  if (failed(verifyCompatibleShapes(outputTypes)))
+    return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (!inputType)
+    return success();
+
+  const int64_t height = inputType.getDimSize(1);
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = inputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+  if (!outputType)
+    return success();
+
+  // Batch and height input/output dimensions should match.
+  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+                                   outputType.getShape().drop_back())))
+    return emitOpError("expected batch and height dimensions of input/output "
+                       "to match, got input=")
+           << inputType << " output=" << outputType;
+
+  // Output width is expected to be input_width / 2 + 1. Compare in exactly
+  // that form (the same formula inferReturnTypeComponents uses) so that
+  // width == 1, where 1 / 2 + 1 == 1, is accepted.
+  const int64_t outputWidth = outputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+      outputWidth != width / 2 + 1)
+    return emitOpError(
+               "expected output width to be equal to input_width / 2 + 1, got ")
+           << outputWidth;
+
+  return success();
+}
+
LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
FFT2dOp::Adaptor adaptor,
@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
return success();
}
+/// Verifies tosa.fft2d beyond what its ODS traits check: when known
+/// statically, the input height and width must be powers of two. Each extent
+/// is taken from whichever of the two inputs has a static value.
+LogicalResult tosa::FFT2dOp::verify() {
+  const auto inputRealType =
+      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+  const auto inputImagType =
+      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+  if (!inputRealType || !inputImagType)
+    return success();
+
+  // Returns `a` when it is static, otherwise `b` (which may itself be
+  // dynamic). If both are static, the SameOperandsAndResultShape trait
+  // requires them to be compatible, so either one may be used.
+  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+    return ShapedType::isDynamic(a) ? b : a;
+  };
+
+  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+                                            inputImagType.getDimSize(1));
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+                                           inputImagType.getDimSize(2));
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  return success();
+}
+
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ConcatOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..71fcd4129a618 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
// -----
-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
- // expected-error at +1 {{failed to legalize operation 'tosa.rfft2d'}}
- %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
- return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
// CHECK-LABEL: @rescale_unsupported_type
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
// expected-error at +1 {{failed to legalize operation 'tosa.rescale'}}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 78f2e173d7cb1..e68783f779063 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1720,20 +1720,20 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
// CHECK-LABEL: func.func @test_static_rfft2d(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_15:.*]] = arith.constant 8 : index
// CHECK: %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
@@ -1741,7 +1741,7 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
// CHECK: %[[VAL_25:.*]] = linalg.index 1 : index
// CHECK: %[[VAL_26:.*]] = linalg.index 2 : index
@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
// CHECK: %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
// CHECK: linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK: } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
// CHECK: }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
- %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
- return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+ %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+ return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
}
// -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..fb7a9222947f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1276,3 +1276,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
: (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
return %0 : tensor<1x4x4x8xf32>
}
+
+// -----
+
+// fft2d: SameOperandsAndResultElementType rejects f32 inputs with f16 outputs.
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+ // expected-error at +1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+ return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+// fft2d: SameOperandsAndResultShape rejects an input_imag width (7) that differs from input_real (8).
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+ // expected-error at +1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+ return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+// fft2d: ResultsAreFloatLike rejects integer (i8) element types.
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+ // expected-error at +1 {{'tosa.fft2d' op requires a floating point type}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+ return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+// fft2d: the custom verifier requires a static height that is a power of two (5 is not).
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+ // expected-error at +1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+ %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+ return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+// rfft2d: SameOperandsAndResultElementType rejects an f32 input with f16 outputs.
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+ // expected-error at +1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+ return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+// rfft2d: the custom verifier requires the real and imaginary outputs to have compatible shapes.
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+ // expected-error at +1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+ return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+// rfft2d: ResultsAreFloatLike rejects integer (i16) element types.
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+ // expected-error at +1 {{'tosa.rfft2d' op requires a floating point type}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+ return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+// rfft2d: the custom verifier requires a static width that is a power of two (9 is not).
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+ // expected-error at +1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+ return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+// rfft2d: output batch and height must be compatible with the input (batch 1 vs 2 here).
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+ // expected-error at +1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+ return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+// rfft2d: output width must be input_width / 2 + 1 (8 / 2 + 1 = 5, not 3).
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+ // expected-error at +1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+ %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+ return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d2958efe1bb24..19584103d40c7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t
// -----
-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
// expected-error at +1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+ return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}
// -----
-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
// expected-error at +1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+ return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}
// -----
-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
// expected-error at +1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+ return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
}
// -----
-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
// expected-error at +1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
%0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
- (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
- return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+ (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+ return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
}
// -----
@@ -577,18 +577,18 @@ func.func @test_maxpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32
// -----
-func.func @test_rfft2d_input_h(%arg0: tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
// expected-error at +1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL}}
- %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
- return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+ %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
+ return %0, %1 : tensor<13x16384x9xf32>, tensor<13x16384x9xf32>
}
// -----
-func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_w(%arg0: tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>) {
// expected-error at +1 {{'tosa.rfft2d' op failed level check: W <= MAX_KERNEL}}
- %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
- return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+ %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>)
+ return %0, %1 : tensor<13x8x8193xf32>, tensor<13x8x8193xf32>
}
// -----
More information about the Mlir-commits
mailing list