[Mlir-commits] [mlir] 3129388 - [mlir][tosa] Add FFT2d operation

Robert Suderman llvmlistbot at llvm.org
Tue Mar 14 12:06:56 PDT 2023


Author: Luke Hutton
Date: 2023-03-14T19:04:52Z
New Revision: 312938864e676c6a0ba9ba173d3ded910e5e48c9

URL: https://github.com/llvm/llvm-project/commit/312938864e676c6a0ba9ba173d3ded910e5e48c9
DIFF: https://github.com/llvm/llvm-project/commit/312938864e676c6a0ba9ba173d3ded910e5e48c9.diff

LOG: [mlir][tosa] Add FFT2d operation

Adds the FFT2d TOSA operation and supporting
shape inference function.

Signed-off-by: Luke Hutton <luke.hutton at arm.com>

Reviewed By: rsuderman, eric-k256

Differential Revision: https://reviews.llvm.org/D144784

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/test/Dialect/Tosa/ops.mlir
    mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index fc4cf7d82174c..be5720caeb0de 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -183,6 +183,36 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Operator: fft2d
+//===----------------------------------------------------------------------===//
+def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Performs FFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched complex 2D Fast Fourier Transform over the input. The
+    complex input values are constructed from the corresponding values in the
+    input_real and input_imag tensors. The resulting values in the output are
+    split into the output_real and output_imag tensors. No normalization is
+    applied on either the forward or inverse versions of the operation.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor3D:$input_real,  // Real parts of the complex input values.
+    Tosa_Tensor3D:$input_imag,  // Imaginary parts of the complex input values.
+
+    BoolAttr:$inverse  // Forward vs. inverse transform; neither normalizes.
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_real,  // Real parts; shape taken from input_real.
+    Tosa_Tensor3D:$output_imag   // Imaginary parts; shape from input_imag.
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Operator: fully_connected
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c3b161ac35032..d7bb6d0bddbf6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -409,6 +409,16 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
 
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
   inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+
+  return success();
+}
+
+LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0))); // output_real takes input_real's shape
+  inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1))); // output_imag takes input_imag's shape (not cross-checked vs input_real)
   return success();
 }
 

diff  --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index fa8257746e875..68eca320aa46b 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -51,6 +51,13 @@ func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4
   return %2 : tensor<1x4x4x8xf32>
 }
 
+// -----
+// CHECK-LABEL: fft2d
+func.func @test_fft2d(%real_in: tensor<1x4x8xf32>, %imag_in: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) { // forward (inverse = false) fft2d on matching 1x4x8 real/imag inputs
+  %real_out, %imag_out = "tosa.fft2d"(%real_in, %imag_in) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  return %real_out, %imag_out : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
 // -----
 // CHECK-LABEL: fully_connected
 func.func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {

diff  --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index c955d57119733..94eea3b36eae2 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1216,3 +1216,21 @@ func.func @test_dynamic_width_rfft2d(%arg0 : tensor<5x2x?xf32>) -> () {
   %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_static_fft2d
+func.func @test_static_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> () { // results declared ?x?x? and not returned, so the expected types can only come from shape inference
+  // CHECK: -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_batch_fft2d
+func.func @test_dynamic_batch_fft2d(%arg0: tensor<?x4x8xf32>, %arg1: tensor<?x4x8xf32>) -> () { // results declared ?x?x? and not returned, so the expected types can only come from shape inference
+  // CHECK: -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>)
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<?x4x8xf32>, tensor<?x4x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}


        


More information about the Mlir-commits mailing list