[Mlir-commits] [mlir] 94f255c - [mlir][tosa] Add RFFT2d operation
Jacques Pienaar
llvmlistbot at llvm.org
Tue Jan 24 15:42:12 PST 2023
Author: Luke Hutton
Date: 2023-01-24T15:42:02-08:00
New Revision: 94f255c2c4d5c6733819affac5d1acb19e3f5e94
URL: https://github.com/llvm/llvm-project/commit/94f255c2c4d5c6733819affac5d1acb19e3f5e94
DIFF: https://github.com/llvm/llvm-project/commit/94f255c2c4d5c6733819affac5d1acb19e3f5e94.diff
LOG: [mlir][tosa] Add RFFT2d operation
Adds the RFFT2d TOSA operation and supporting
shape inference function.
Signed-off-by: Luke Hutton <luke.hutton at arm.com>
Change-Id: I7e49c47cdd846cdc1b187545ef76d5cda2d5d9ad
Reviewed By: jpienaar
Differential Revision: https://reviews.llvm.org/D142336
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/test/Dialect/Tosa/ops.mlir
mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index afdd8017cec58..35a099684837e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
#include "mlir/Dialect/Traits.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6609c6bdf8199..f9221662b1d7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -270,6 +270,34 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
let hasCanonicalizer = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: rfft2d
+//===----------------------------------------------------------------------===//
+def Tosa_RFFT2dOp : Tosa_Op<"rfft2d", [
+    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                              ["inferReturnTypeComponents"]>,
+    Pure]> {
+  let summary = "Performs RFFT2D operation on the input.";
+
+  let description = [{
+    Performs a batched 2D real-valued Fast Fourier Transform over the input,
+    where the input tensor consists of real values, producing complex-valued
+    output. The complex output values are split into the `output_real` and
+    `output_imag` result tensors. RFFT2D takes advantage of Hermitian symmetry
+    to only calculate the first half of the final output axis: for an input of
+    shape [N, H, W], both results have shape [N, H, W/2 + 1]. Imaginary values
+    at locations (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are zero.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor3D:$input
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$output_real,
+    Tosa_Tensor3D:$output_imag
+  );
+}
+
//===----------------------------------------------------------------------===//
// Operator: transpose_conv2d
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 338f5308e38be..e78a0b1543223 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -387,6 +387,31 @@ LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
return success();
}
+// Infers the shapes of both rfft2d results from the input shape.
+// For an input of shape [N, H, W], output_real and output_imag both get
+// shape [N, H, W / 2 + 1]: only the first half of the last axis (W / 2 + 1
+// entries) is produced. Only dimensions are inferred; no element type is
+// supplied in the returned components.
+// Returns failure() when the input is unranked or not rank 3.
+LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape = operands.getShape(0);
+
+  // Nothing can be inferred from an unranked input. A ranked input that is
+  // not rank 3 is rejected too, so the getDimSize(0..2) calls below never
+  // index past the end of the shape if this runs on an unverified op.
+  if (!inputShape.hasRank() || inputShape.getRank() != 3)
+    return failure();
+
+  llvm::SmallVector<int64_t> outputShape;
+  outputShape.resize(3, ShapedType::kDynamic);
+  // Leading (batch) and height dimensions pass straight through.
+  outputShape[0] = inputShape.getDimSize(0);
+  outputShape[1] = inputShape.getDimSize(1);
+  int64_t inWidth = inputShape.getDimSize(2);
+
+  // The output width is only known when the input width is static; otherwise
+  // it stays dynamic. Note that we can support this calculation symbolically
+  // in the future, e.g. [x, y, z] -> [x, y, z / 2 + 1].
+  if (inWidth != ShapedType::kDynamic)
+    outputShape[2] = inWidth / 2 + 1;
+
+  // output_real and output_imag always share the same shape.
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 4599ca846fb71..fa8257746e875 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -72,6 +72,13 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32
return %0 : tensor<1x32x32x8xf32>
}
+// -----
+// CHECK-LABEL: rfft2d
+// Round-trip a static rfft2d: a 16-wide input produces 16 / 2 + 1 = 9 columns.
+func.func @test_rfft2d(%input: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
+  %real, %imag = "tosa.rfft2d"(%input) : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
+  return %real, %imag : tensor<13x8x9xf32>, tensor<13x8x9xf32>
+}
+
// -----
// CHECK-LABEL: transpose_conv2d
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index f6740dbd09000..c955d57119733 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1189,3 +1189,30 @@ func.func @while_test(%arg0 : tensor<i32>, %arg1 : tensor<1xi32>) -> () {
}) : (tensor<i32>, tensor<1xi32>) -> (tensor<*xi32>, tensor<*xi32>)
return
}
+
+// -----
+
+// CHECK-LABEL: @test_static_rfft2d
+func.func @test_static_rfft2d(%input: tensor<5x2x8xf32>) -> () {
+  // Fully static input: width 8 becomes 8 / 2 + 1 = 5, other dims pass through.
+  // CHECK: -> (tensor<5x2x5xf32>, tensor<5x2x5xf32>)
+  %real, %imag = "tosa.rfft2d"(%input) : (tensor<5x2x8xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_batch_rfft2d
+func.func @test_dynamic_batch_rfft2d(%input : tensor<?x2x4xf32>) -> () {
+  // A dynamic leading dim stays dynamic; static width 4 becomes 4 / 2 + 1 = 3.
+  // CHECK: -> (tensor<?x2x3xf32>, tensor<?x2x3xf32>)
+  %real, %imag = "tosa.rfft2d"(%input) : (tensor<?x2x4xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @test_dynamic_width_rfft2d
+func.func @test_dynamic_width_rfft2d(%input : tensor<5x2x?xf32>) -> () {
+  // A dynamic width cannot be halved statically, so the output width stays dynamic.
+  // CHECK: -> (tensor<5x2x?xf32>, tensor<5x2x?xf32>)
+  %real, %imag = "tosa.rfft2d"(%input) : (tensor<5x2x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return
+}
More information about the Mlir-commits
mailing list