[Mlir-commits] [mlir] [mlir][tosa] Add verifiers for FFT2d and RFFT2d (PR #129273)

Luke Hutton llvmlistbot at llvm.org
Fri Feb 28 09:15:48 PST 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/129273

Adds checks for element types and input/output shapes.

>From c6e810d1c68cc6f95f713b70f94531ecfb2db016 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 26 Feb 2025 15:49:17 +0000
Subject: [PATCH] [mlir][tosa] Add verifiers for FFT2d and RFFT2d

Adds checks for element types and input/output shapes.

Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: Ib40928027f5b9d75306aa662c4627e3263db7de7
---
 mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td  | 13 ++-
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          | 80 ++++++++++++++++++-
 .../TosaToLinalg/tosa-to-linalg-invalid.mlir  |  9 ---
 .../TosaToLinalg/tosa-to-linalg.mlir          | 24 +++---
 mlir/test/Dialect/Tosa/invalid.mlir           | 80 +++++++++++++++++++
 mlir/test/Dialect/Tosa/level_check.mlir       | 36 ++++-----
 6 files changed, 200 insertions(+), 42 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ddfec2c9bfcd3..9e14daf0d014c 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -240,7 +240,10 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: fft2d
 //===----------------------------------------------------------------------===//
-def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
+def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d", [
+    SameOperandsAndResultElementType,
+    SameOperandsAndResultShape,
+    ResultsAreFloatLike]> {
   let summary = "Performs FFT2D operation on the input.";
 
   let description = [{
@@ -279,6 +282,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
     $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
     type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -349,7 +354,9 @@ def Tosa_MaxPool2dOp : Tosa_InferShapedTypeOp<"max_pool2d"> {
 //===----------------------------------------------------------------------===//
 // Operator: rfft2d
 //===----------------------------------------------------------------------===//
-def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
+def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d", [
+    SameOperandsAndResultElementType,
+    ResultsAreFloatLike]> {
   let summary = "Performs RFFT2D operation on the input.";
 
   let description = [{
@@ -385,6 +392,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
   let assemblyFormat = [{
     $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`
   }];
+
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 7b50eceb081dd..63afaed22d7ab 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -789,7 +789,7 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   int64_t inWidth = inputShape.getDimSize(2);
 
   // Note that we can support this calculation symbolically
-  // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1]
+  // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
   if (inWidth != ShapedType::kDynamic)
     outputShape[2] = inWidth / 2 + 1;
 
@@ -799,6 +799,57 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+// Emits an op error on `op` unless `dimSize` is a positive power of two.
+static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
+                                           const llvm::StringRef dimName) {
+  // Test positivity first so `dimSize - 1` cannot overflow for INT64_MIN.
+  if (dimSize > 0 && (dimSize & (dimSize - 1)) == 0)
+    return success();
+  return op->emitOpError("expected ")
+         << dimName << " to be a power of two, got " << dimSize;
+}
+
+// Checks result shapes, power-of-two height/width and the output width.
+LogicalResult tosa::RFFT2dOp::verify() {
+  const auto outputTypes = getResultTypes();
+  if (failed(verifyCompatibleShapes(outputTypes)))
+    return emitOpError("expected output shapes to match, got ") << outputTypes;
+
+  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
+  if (!inputType)
+    return success();
+
+  const int64_t height = inputType.getDimSize(1);
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = inputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+
+  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
+  if (!outputType)
+    return success();
+
+  // Batch and height input/output dimensions should match
+  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
+                                   outputType.getShape().drop_back())))
+    return emitOpError("expected batch and height dimensions of input/output "
+                       "to match, got input=")
+           << inputType << " output=" << outputType;
+
+  // Output width must be input_width / 2 + 1, which is 1 for input_width 1.
+  const int64_t outputWidth = outputType.getDimSize(2);
+  if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
+      outputWidth != width / 2 + 1)
+    return emitOpError(
+               "expected output width to be equal to input_width / 2 + 1, got ")
+           << outputWidth;
+  return success();
+}
+
 LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     FFT2dOp::Adaptor adaptor,
@@ -810,6 +861,33 @@ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
   return success();
 }
 
+// Checks that statically known input height/width are powers of two.
+LogicalResult tosa::FFT2dOp::verify() {
+  const auto inputRealType =
+      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
+  const auto inputImagType =
+      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
+  if (!inputRealType || !inputImagType)
+    return success();
+
+  // Pick the static dim when only one of the two inputs has it dynamic.
+  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
+    return ShapedType::isDynamic(a) ? b : a;
+  };
+  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
+                                            inputImagType.getDimSize(1));
+  if (!ShapedType::isDynamic(height) &&
+      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
+    return failure();
+
+  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
+                                           inputImagType.getDimSize(2));
+  if (!ShapedType::isDynamic(width) &&
+      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
+    return failure();
+  return success();
+}
+
 LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ConcatOp::Adaptor adaptor,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index 5db3f56cf459e..71fcd4129a618 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -31,15 +31,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
 
 // -----
 
-// CHECK-LABEL: @rfft2d_with_non_float_type
-func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
-  // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}
-  %real, %imag = tosa.rfft2d %arg0 : (tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>)
-  return %real, %imag : tensor<1x1x1xi32>, tensor<1x1x1xi32>
-}
-
-// -----
-
 // CHECK-LABEL: @rescale_unsupported_type
 func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
   // expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 78f2e173d7cb1..e68783f779063 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1720,20 +1720,20 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
 
 // CHECK-LABEL:   func.func @test_static_rfft2d(
-// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 1 : index
 // CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
 // CHECK:           %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_4:.*]] = arith.constant 4 : index
 // CHECK:           %[[VAL_5:.*]] = arith.constant 5 : index
-// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_6:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK:           %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:           %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
+// CHECK:           %[[VAL_9:.*]] = tensor.empty() : tensor<5x4x5xf32>
 // CHECK:           %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK:           %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x4x5xf32>) -> tensor<5x4x5xf32>
 // CHECK:           %[[VAL_12:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK:           %[[VAL_13:.*]] = arith.constant 4 : index
 // CHECK:           %[[VAL_14:.*]] = arith.constant 2 : index
 // CHECK:           %[[VAL_15:.*]] = arith.constant 8 : index
 // CHECK:           %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
@@ -1741,7 +1741,7 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK:           %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
 // CHECK:           %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
 // CHECK:           %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
-// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK:           %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x4x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
 // CHECK:           ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
 // CHECK:             %[[VAL_25:.*]] = linalg.index 1 : index
 // CHECK:             %[[VAL_26:.*]] = linalg.index 2 : index
@@ -1766,12 +1766,12 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
 // CHECK:             %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
 // CHECK:             %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
 // CHECK:             linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
-// CHECK:           } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK:           } -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+// CHECK:           return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 // CHECK:         }
-func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+func.func @test_static_rfft2d(%arg0: tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>) {
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x4x8xf32>) -> (tensor<5x4x5xf32>, tensor<5x4x5xf32>)
+  return %output_real, %output_imag : tensor<5x4x5xf32>, tensor<5x4x5xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 123c65e1b4fcd..fb7a9222947f6 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1276,3 +1276,83 @@ func.func @test_conv2d_invalid_bias_size(%arg0: tensor<1x4x4x4xf32>, %arg1: tens
     : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x4x8xf32>
   return %0 : tensor<1x4x4x8xf32>
 }
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf16>, tensor<1x4x8xf16>)
+  return %0, %1 : tensor<1x4x8xf16>, tensor<1x4x8xf16>
+}
+
+// -----
+
+func.func @test_fft2d_same_operands_and_result_shape(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires the same shape for all operands and results}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x7xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
+// -----
+
+func.func @test_fft2d_invalid_type(%arg0: tensor<1x4x8xi8>, %arg1: tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>) {
+  // expected-error@+1 {{'tosa.fft2d' op requires a floating point type}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xi8>, tensor<1x4x8xi8>) -> (tensor<1x4x8xi8>, tensor<1x4x8xi8>)
+  return %0, %1 : tensor<1x4x8xi8>, tensor<1x4x8xi8>
+}
+
+// -----
+
+func.func @test_fft2d_height_non_power_of_two(%arg0: tensor<1x5x8xf32>, %arg1: tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>) {
+  // expected-error@+1 {{'tosa.fft2d' op expected height to be a power of two, got 5}}
+  %0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x5x8xf32>, tensor<1x5x8xf32>) -> (tensor<1x5x8xf32>, tensor<1x5x8xf32>)
+  return %0, %1 : tensor<1x5x8xf32>, tensor<1x5x8xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_same_operands_and_result_element_type(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires the same element type for all operands and results}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_same_results_shape(%arg0: tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output shapes to match, got 'tensor<1x4x6xf32>', 'tensor<1x4x5xf32>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf32>) -> (tensor<1x4x6xf32>, tensor<1x4x5xf32>)
+  return %0, %1 : tensor<1x4x6xf32>, tensor<1x4x5xf32>
+}
+
+// -----
+
+func.func @test_rfft2d_invalid_type(%arg0: tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op requires a floating point type}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xi16>) -> (tensor<1x4x5xi16>, tensor<1x4x5xi16>)
+  return %0, %1 : tensor<1x4x5xi16>, tensor<1x4x5xi16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_power_of_two(%arg0: tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected width to be a power of two, got 9}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x9xf16>) -> (tensor<1x4x5xf16>, tensor<1x4x5xf16>)
+  return %0, %1 : tensor<1x4x5xf16>, tensor<1x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_batch_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected batch and height dimensions of input/output to match, got input='tensor<1x4x8xf16>' output='tensor<2x4x5xf16>'}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<2x4x5xf16>, tensor<2x4x5xf16>)
+  return %0, %1 : tensor<2x4x5xf16>, tensor<2x4x5xf16>
+}
+
+// -----
+
+func.func @test_rfft2d_width_input_output_match(%arg0: tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>) {
+  // expected-error@+1 {{'tosa.rfft2d' op expected output width to be equal to input_width / 2 + 1, got 3}}
+  %0, %1 = tosa.rfft2d %arg0 {inverse = false} : (tensor<1x4x8xf16>) -> (tensor<1x4x3xf16>, tensor<1x4x3xf16>)
+  return %0, %1 : tensor<1x4x3xf16>, tensor<1x4x3xf16>
+}
diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir
index d2958efe1bb24..19584103d40c7 100644
--- a/mlir/test/Dialect/Tosa/level_check.mlir
+++ b/mlir/test/Dialect/Tosa/level_check.mlir
@@ -477,38 +477,38 @@ func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: t
 
 // -----
 
-func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_real_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_h(%arg0: tensor<32x16384x32xf32>, %arg1: tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>) -> (tensor<32x16384x32xf32>, tensor<32x16384x32xf32>)
+  return %0, %1 : tensor<32x16384x32xf32>, tensor<32x16384x32xf32>
 }
 
 // -----
 
-func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) {
+func.func @test_fft2d_imag_w(%arg0: tensor<32x32x16384xf32>, %arg1: tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) {
   // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}}
   %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } :
-            (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>)
-  return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32>
+            (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>) -> (tensor<32x32x16384xf32>, tensor<32x32x16384xf32>)
+  return %0, %1 : tensor<32x32x16384xf32>, tensor<32x32x16384xf32>
 }
 
 // -----
@@ -577,18 +577,18 @@ func.func @test_maxpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32
 
 // -----
 
-func.func @test_rfft2d_input_h(%arg0: tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_h(%arg0: tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>) {
   // expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL}}
-  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
-  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x16384x16xf32>) -> (tensor<13x16384x9xf32>, tensor<13x16384x9xf32>)
+  return %0, %1 : tensor<13x16384x9xf32>, tensor<13x16384x9xf32>
 }
 
 // -----
 
-func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+func.func @test_rfft2d_input_w(%arg0: tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>) {
   // expected-error@+1 {{'tosa.rfft2d' op failed level check: W <= MAX_KERNEL}}
-  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
-  return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+  %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x16384xf32>) -> (tensor<13x8x8193xf32>, tensor<13x8x8193xf32>)
+  return %0, %1 : tensor<13x8x8193xf32>, tensor<13x8x8193xf32>
 }
 
 // -----



More information about the Mlir-commits mailing list