[Mlir-commits] [mlir] [TOSA] Fix RFFT2D verifier for width=1 (PR #130279)
Thomas Preud'homme
llvmlistbot at llvm.org
Fri Mar 7 04:46:18 PST 2025
https://github.com/RoboTux created https://github.com/llvm/llvm-project/pull/130279
The current formula assumes the input width is a multiple of 2, but TOSA only
requires it to be a power of 2, which 1 is.
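
As a rough illustration of the difference, the two width checks can be compared in isolation. This is only a sketch: the helper names `oldCheckPasses` and `newCheckPasses` are hypothetical and do not exist in MLIR, and the dynamic-dimension guards from the verifier are left out.

```cpp
#include <cstdint>

// Hypothetical helpers for illustration only; not part of MLIR.

// Condition used before the patch: accepts only when
// (outputWidth - 1) * 2 == width, which can never hold for an odd width.
constexpr bool oldCheckPasses(std::int64_t width, std::int64_t outputWidth) {
  return (outputWidth - 1) * 2 == width;
}

// Condition used after the patch: accepts when
// outputWidth == width / 2 + 1, using integer division.
constexpr bool newCheckPasses(std::int64_t width, std::int64_t outputWidth) {
  return outputWidth == (width / 2) + 1;
}

// width = 16, output width = 9: both forms accept.
static_assert(oldCheckPasses(16, 9), "old check accepts 16 -> 9");
static_assert(newCheckPasses(16, 9), "new check accepts 16 -> 9");

// width = 1, output width = 1 / 2 + 1 = 1: only the new form accepts.
static_assert(!oldCheckPasses(1, 1), "old check rejects 1 -> 1");
static_assert(newCheckPasses(1, 1), "new check accepts 1 -> 1");
```

For even widths the two conditions agree; they differ only when the width is odd, and 1 is the only odd power of 2.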
From 34350d9e531a704f3b6ed14cfe92e3abe21a63d4 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme <thomas.preudhomme at arm.com>
Date: Thu, 6 Mar 2025 23:06:15 +0000
Subject: [PATCH] [TOSA] Fix RFFT2D verifier for width=1
Current formula assumes width is a multiple of 2 but TOSA only requires
a power of 2, which 1 is.
---
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 2 +-
mlir/test/Dialect/Tosa/ops.mlir | 7 +++++++
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ea4414fc1890e..f8588aa0ace0f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -888,7 +888,7 @@ LogicalResult tosa::RFFT2dOp::verify() {
// Output width dimension expected to be input_width / 2 + 1
const int64_t outputWidth = outputType.getDimSize(2);
if (!ShapedType::isDynamic(width) && !ShapedType::isDynamic(outputWidth) &&
- (outputWidth - 1) * 2 != width)
+ (outputWidth != (width / 2) + 1))
return emitOpError(
"expected output width to be equal to input_width / 2 + 1, got ")
<< outputWidth;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 4c2cda8d9c027..600afe2abbff2 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -175,6 +175,13 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tenso
return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
}
+// -----
+// CHECK-LABEL: rfft2d_width1
+func.func @test_rfft2d_width1(%arg0: tensor<1x1x1xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>) {
+ %0, %1 = tosa.rfft2d %arg0 : (tensor<1x1x1xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>)
+ return %0, %1 : tensor<1x1x1xf32>, tensor<1x1x1xf32>
+}
+
// -----
// CHECK-LABEL: rfft2d_with_local_bound
func.func @test_rfft2d_with_local_bound(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {