[Mlir-commits] [mlir] 3129388 - [mlir][tosa] Add FFT2d operation
Robert Suderman
llvmlistbot at llvm.org
Tue Mar 14 12:06:56 PDT 2023
Author: Luke Hutton
Date: 2023-03-14T19:04:52Z
New Revision: 312938864e676c6a0ba9ba173d3ded910e5e48c9
URL: https://github.com/llvm/llvm-project/commit/312938864e676c6a0ba9ba173d3ded910e5e48c9
DIFF: https://github.com/llvm/llvm-project/commit/312938864e676c6a0ba9ba173d3ded910e5e48c9.diff
LOG: [mlir][tosa] Add FFT2d operation
Adds the FFT2d TOSA operation and a supporting
shape inference function.
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Reviewed By: rsuderman, eric-k256
Differential Revision: https://reviews.llvm.org/D144784
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index fc4cf7d82174c..be5720caeb0de 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -183,6 +183,36 @@ def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: fft2d
+//===----------------------------------------------------------------------===//
+// Batched complex 2D FFT whose complex input and output values are each split
+// across a pair of real/imaginary tensors. Result shapes are computed by
+// inferReturnTypeComponents, declared through InferShapedTypeOpInterface below.
+// NOTE(review): unlike some neighbouring ops this def does not set
+// hasVerifier, so nothing here checks that input_real and input_imag agree.
+def Tosa_FFT2dOp : Tosa_Op<"fft2d", [
+ DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+ ["inferReturnTypeComponents"]>,
+ Pure]> {
+ let summary = "Performs FFT2D operation on the input.";
+
+ let description = [{
+ Performs a batched complex 2D Fast Fourier Transform over the input. The
+ complex input values are constructed from the corresponding values in the
+ input_real and input_imag tensors. The resulting values in the output are
+ split into the output_real and output_imag tensors. No normalization is
+ applied on either the forward or inverse versions of the operation.
+ }];
+
+ let arguments = (ins
+ Tosa_Tensor3D:$input_real,
+ Tosa_Tensor3D:$input_imag,
+
+ // Chooses the inverse transform when true and the forward transform when
+ // false; per the description, neither direction is normalized.
+ BoolAttr:$inverse
+ );
+
+ let results = (outs
+ Tosa_Tensor3D:$output_real,
+ Tosa_Tensor3D:$output_imag
+ );
+}
+
//===----------------------------------------------------------------------===//
// Operator: fully_connected
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c3b161ac35032..d7bb6d0bddbf6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -409,6 +409,16 @@ LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+
+ return success();
+}
+
+// Shape inference for tosa.fft2d. The results are (output_real, output_imag)
+// and the operands are (input_real, input_imag), so each output takes the shape
+// of the input in the same position. Always succeeds.
+// NOTE(review): the two input shapes are not joined, so if one input carries
+// more static dimension information than the other, the two outputs end up
+// with different levels of refinement. This function also does not check that
+// the input shapes agree -- TODO confirm whether a verifier should.
+ LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ // output_real mirrors input_real (operand 0).
+ inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(0)));
+ // output_imag mirrors input_imag (operand 1).
+ inferredReturnShapes.push_back(ShapedTypeComponents(operands.getShape(1)));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index fa8257746e875..68eca320aa46b 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -51,6 +51,13 @@ func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4
return %2 : tensor<1x4x4x8xf32>
}
+// -----
+// Forward transform (inverse = false) over matching 1x4x8 real/imag inputs.
+// CHECK-LABEL: fft2d
+func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
+ %0, %1 = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+ return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
+}
+
// -----
// CHECK-LABEL: fully_connected
func.func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index c955d57119733..94eea3b36eae2 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1216,3 +1216,21 @@ func.func @test_dynamic_width_rfft2d(%arg0 : tensor<5x2x?xf32>) -> () {
%output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
return
}
+
+// -----
+
+// CHECK-LABEL: @test_static_fft2d
+func.func @test_static_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> () {
+ // Results are declared fully dynamic and the function returns nothing, so
+ // the directive below can only match the op after shape inference has
+ // refined both outputs to the static input shape.
+ // CHECK: -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
+ %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_batch_fft2d
+func.func @test_dynamic_batch_fft2d(%arg0: tensor<?x4x8xf32>, %arg1: tensor<?x4x8xf32>) -> () {
+ // Results are declared fully dynamic and the function returns nothing, so
+ // the directive below can only match once shape inference has refined the
+ // static H/W dimensions while keeping the batch dimension dynamic.
+ // CHECK: -> (tensor<?x4x8xf32>, tensor<?x4x8xf32>)
+ %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = false} : (tensor<?x4x8xf32>, tensor<?x4x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+ return
+}
More information about the Mlir-commits
mailing list