[Mlir-commits] [mlir] 68ef0e9 - [mlir][tosa] Implement lowering for tosa.rfft2d

Eric Kunze llvmlistbot at llvm.org
Tue May 30 16:11:07 PDT 2023


Author: Spenser Bauman
Date: 2023-05-30T16:09:17-07:00
New Revision: 68ef0e95b20ac1bebb119977fe7c9ac08a764ebe

URL: https://github.com/llvm/llvm-project/commit/68ef0e95b20ac1bebb119977fe7c9ac08a764ebe
DIFF: https://github.com/llvm/llvm-project/commit/68ef0e95b20ac1bebb119977fe7c9ac08a764ebe.diff

LOG: [mlir][tosa] Implement lowering for tosa.rfft2d

Implement a lowering for tosa.rfft2d to linalg.generic in the
TosaToLinalg transform.

Reviewed By: eric-k256

Differential Revision: https://reviews.llvm.org/D151095

Added: 
    

Modified: 
    mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
    mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 9e0cccff6cf99..0ca05882cca74 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -12,7 +12,9 @@
 
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -20,6 +22,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -2021,6 +2024,162 @@ class TableConverter : public OpRewritePattern<tosa::TableOp> {
   }
 };
 
+/// Lowers `tosa.rfft2d` to one `linalg.generic` that evaluates the 2-D
+/// discrete Fourier transform of a real input by direct summation.
+///
+/// The generic op iterates over (n, oy, ox, iy, ix): the first three
+/// dimensions are parallel and index the outputs, the last two are reductions
+/// that walk the input's H and W dimensions. Each step accumulates
+///   real += in * cos(angle)
+///   imag -= in * sin(angle)
+/// with angle = 2 * pi * ((iy * oy) / H + (ix * ox) / W).
+///
+/// Both accumulators have shape [N, H, W / 2 + 1].
+struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
+  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
+
+  /// Returns true if `type` is a ranked tensor type.
+  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
+
+  /// Returns `(ofr / 2) + 1` computed with unsigned index arithmetic.
+  ///
+  /// Both operations go through `createOrFold`, and the result is converted
+  /// back with `getAsOpFoldResult`, so a constant input yields a constant
+  /// result rather than a chain of arith ops.
+  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
+                                  OpFoldResult ofr) {
+    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
+
+    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
+    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
+    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
+    return getAsOpFoldResult(plusOne);
+  }
+
+  /// Computes the result tensor type [N, H, W / 2 + 1] for `input`.
+  ///
+  /// \param input         Ranked tensor whose third dimension is halved.
+  /// \param dynamicSizes  Receives one SSA value per dynamic result dimension,
+  ///                      in dimension order, suitable for `tensor.empty`.
+  /// \returns A ranked tensor type with the element type of `input`.
+  static RankedTensorType
+  computeOutputShape(OpBuilder &builder, Location loc, Value input,
+                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
+    // Get [N, H, W]
+    auto dims = linalg::getMixedDimensions(builder, loc, input);
+
+    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
+    // output tensors.
+    dims[2] = halfPlusOne(builder, loc, dims[2]);
+
+    // Split the mixed sizes into the static shape and the dynamic operands.
+    llvm::SmallVector<int64_t, 3> staticSizes;
+    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
+
+    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
+    return RankedTensorType::get(staticSizes, elementType);
+  }
+
+  /// Creates a tensor of type `type` filled with the zero value of its element
+  /// type (`tensor.empty` followed by `linalg.fill`). It serves as the initial
+  /// value of a reduction accumulator.
+  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
+                                RankedTensorType type,
+                                llvm::ArrayRef<Value> dynamicSizes) {
+    auto emptyTensor =
+        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
+    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
+    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
+    auto filledTensor = rewriter
+                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
+                                                    ValueRange{emptyTensor})
+                            .result();
+    return filledTensor;
+  }
+
+  /// Converts an `index` value to the float type `type`, going through an
+  /// unsigned cast to i64 followed by an unsigned int-to-float conversion.
+  static Value castIndexToFloat(OpBuilder &builder, Location loc,
+                                FloatType type, Value value) {
+    auto integerVal =
+        builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+
+    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
+  }
+
+  /// Emits `linalg.index <index>` and returns it converted to `type`. Must be
+  /// called while building the body of a linalg op.
+  static Value createLinalgIndex(OpBuilder &builder, Location loc,
+                                 FloatType type, int64_t index) {
+    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
+    return castIndexToFloat(builder, loc, type, indexVal);
+  }
+
+  /// Returns one affine dimension expression per position given in `args`.
+  template <typename... Args>
+  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
+                                                         Args... args) {
+    return {builder.getAffineDimExpr(args)...};
+  }
+
+  /// Replaces `rfft2d` with the `linalg.generic` described on the struct.
+  ///
+  /// Fails to match, without modifying the IR, when any operand or result is
+  /// not a ranked tensor or when the input element type is not a float type.
+  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
+        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
+      return rewriter.notifyMatchFailure(rfft2d,
+                                         "only supports ranked tensors");
+    }
+
+    auto loc = rfft2d.getLoc();
+    auto input = rfft2d.getInput();
+
+    // The body below is built purely from floating-point ops. Reject other
+    // element types here, before any IR has been created, rather than
+    // asserting inside a checked cast.
+    auto elementType =
+        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
+    if (!elementType) {
+      return rewriter.notifyMatchFailure(
+          rfft2d, "only supports floating-point element types");
+    }
+
+    // Compute the output type and set of dynamic sizes
+    llvm::SmallVector<Value> dynamicSizes;
+    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
+
+    // Iterator types for the linalg.generic implementation: (n, oy, ox) are
+    // parallel, (iy, ix) are reductions.
+    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction,
+        utils::IteratorType::reduction};
+
+    // Inputs/outputs to the linalg.generic implementation. The two outputs
+    // are the zero-initialized real and imaginary accumulators.
+    llvm::SmallVector<Value> genericOpInputs = {input};
+    llvm::SmallVector<Value> genericOpOutputs = {
+        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
+        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
+
+    // Indexing maps for input and output tensors: the input is read at
+    // (d0, d3, d4), both outputs are written at (d0, d1, d2).
+    auto indexingMaps = AffineMap::inferFromExprList(llvm::ArrayRef{
+        affineDimsExpr(rewriter, 0, 3, 4), affineDimsExpr(rewriter, 0, 1, 2),
+        affineDimsExpr(rewriter, 0, 1, 2)});
+
+    // Width and height dimensions of the original input.
+    auto dimH = linalg::createOrFoldDimOp(rewriter, loc, input, 1);
+    auto dimW = linalg::createOrFoldDimOp(rewriter, loc, input, 2);
+
+    // Constants and dimension sizes, materialized outside the generic body so
+    // they are computed once rather than per iteration.
+    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
+    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
+    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
+    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
+
+    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
+      // Block arguments: the input element and the two running sums.
+      Value valReal = args[0];
+      Value sumReal = args[1];
+      Value sumImag = args[2];
+
+      // Indices for angle computation
+      auto oy = createLinalgIndex(builder, loc, elementType, 1);
+      auto ox = createLinalgIndex(builder, loc, elementType, 2);
+      auto iy = createLinalgIndex(builder, loc, elementType, 3);
+      auto ix = createLinalgIndex(builder, loc, elementType, 4);
+
+      // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
+      auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
+      auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
+      auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
+      auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
+      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+
+      // realComponent = valReal * cos(angle)
+      // imagComponent = valReal * sin(angle)
+      auto cosAngle = builder.create<math::CosOp>(loc, angle);
+      auto sinAngle = builder.create<math::SinOp>(loc, angle);
+      auto realComponent =
+          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
+      auto imagComponent =
+          builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+
+      // outReal = sumReal + realComponent
+      // outImag = sumImag - imagComponent
+      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
+      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);
+
+      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+    };
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
+        indexingMaps, iteratorTypes, buildBody);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgConversionPatterns(
@@ -2083,6 +2242,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       GatherConverter,
       RescaleConverter,
       ReverseConverter,
+      RFFT2dConverter,
       TableConverter,
       TileConverter,
       TransposeConverter>(patterns->getContext());

diff  --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9e5615e5c33f9..1f66c669bafb6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1412,3 +1412,132 @@ func.func @select_fp32(%arg0: tensor<1x1x5x5xi1>, %arg1: tensor<1x12x5x5xf32>, %
   return %0 : tensor<1x12x5x5xf32>
 }
 
+// -----
+
+// Static-shape lowering of tosa.rfft2d: a 5x5x8 input yields 5x5x5 real and
+// imaginary results (the last dimension is 8 / 2 + 1 = 5). The directives
+// below expect two zero-filled accumulators followed by one linalg.generic
+// with three parallel and two reduction iterators.
+// NOTE(review): several operand patterns below of the form %[[NAME:.*]] reuse
+// a name that was already captured on an earlier line; FileCheck treats that
+// as a fresh definition, so those operands are not pinned to the earlier
+// value -- confirm whether that is intended.
+// NOTE(review): XCOMP is built from linalg.index 3 and 1 and YCOMP from
+// linalg.index 4 and 2, so the two names look swapped relative to the
+// height/width terms they hold -- confirm.
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_static_rfft2d
+// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK:   %[[CST_1:.*]] = arith.constant 1 : index
+// CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
+// CHECK:   %[[CST_8:.*]] = arith.constant 8 : index
+// CHECK:   %[[CST_4:.*]] = arith.constant 4 : index
+// CHECK:   %[[CST_5:.*]] = arith.constant 5 : index
+// CHECK:   %[[EMPTY_0:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:   %[[CST_ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[VAR_1:.*]] = linalg.fill ins(%[[CST_ZERO:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK:   %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK:   %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
+// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
+// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
+// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
+// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK:   linalg.generic {
+// CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
+// CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK:     ins(%[[ARG_0]] : tensor<5x5x8xf32>)
+// CHECK:     outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
+// CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
+// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
+// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
+// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
+// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
+// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
+// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
+// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
+// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK:     %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
+// CHECK:     %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
+// CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
+// CHECK:     %[[YCOMP:.*]] = arith.divf %[[VAR_22]], %[[VAR_8]] : f32
+// CHECK:     %[[VAR_25:.*]] = arith.addf %[[XCOMP]], %[[YCOMP]] : f32
+// CHECK:     %[[ALPHA:.*]] = arith.mulf %[[CST_PI]], %[[VAR_25]] : f32
+// CHECK:     %[[COS_ALPHA:.*]] = math.cos %[[ALPHA]] : f32
+// CHECK:     %[[SIN_ALPHA:.*]] = math.sin %[[ALPHA]] : f32
+// CHECK:     %[[REAL_CONTRIB:.*]] = arith.mulf %[[IN]], %[[COS_ALPHA]] : f32
+// CHECK:     %[[IMAG_CONTRIB:.*]] = arith.mulf %[[IN]], %[[SIN_ALPHA]] : f32
+// CHECK:     %[[OUT_REAL:.*]] = arith.addf %[[OUT_0]], %[[REAL_CONTRIB]] : f32
+// CHECK:     %[[OUT_IMAG:.*]] = arith.subf %[[OUT_1]], %[[IMAG_CONTRIB]] : f32
+// CHECK:     linalg.yield %[[OUT_REAL]], %[[OUT_IMAG]] : f32, f32
+// CHECK:   } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
+
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
+  return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+}
+
+// -----
+
+// Dynamic-shape lowering of tosa.rfft2d: all three input dimensions are
+// unknown, so the expected IR queries them with tensor.dim, computes the
+// output width as (W divui 2) + 1, and passes the three sizes to both
+// tensor.empty ops. Dimensions 1 and 2 are then queried a second time to form
+// the floating-point divisors used inside the linalg.generic body.
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK-LABEL: @test_dynamic_rfft2d
+// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK:   %[[CST_0:.*]] = arith.constant 0 : index
+// CHECK:   %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[CST_0]] : tensor<?x?x?xf32>
+// CHECK:   %[[CST_1:.*]] = arith.constant 1 : index
+// CHECK:   %[[DIM_0:.*]] = tensor.dim %[[ARG_0]], %[[CST_1]] : tensor<?x?x?xf32>
+// CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
+// CHECK:   %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
+// CHECK:   %[[CST_1_2:.*]] = arith.constant 1 : index
+// CHECK:   %[[CST_2_3:.*]] = arith.constant 2 : index
+// CHECK:   %[[VAR_0:.*]] = arith.divui %[[DIM_1]], %[[CST_2_3]] : index
+// CHECK:   %[[VAR_1:.*]] = arith.addi %[[VAR_0]], %[[CST_1_2]] : index
+// CHECK:   %[[EMPTY_0:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
+// CHECK:   %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:   %[[EMPTY_1:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
+// CHECK:   %[[CST_4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[VAR_5:.*]] = linalg.fill ins(%[[CST_4]] : f32) outs(%[[EMPTY_1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:   %[[CST_1_5:.*]] = arith.constant 1 : index
+// CHECK:   %[[DIM_6:.*]] = tensor.dim %[[ARG_0]], %[[CST_1_5]] : tensor<?x?x?xf32>
+// CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
+// CHECK:   %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
+// CHECK:   %[[CST_9:.*]] = arith.constant 6.28318548 : f32
+// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
+// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
+// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
+// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK:   linalg.generic {
+// CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
+// CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
+// CHECK:     ins(%[[ARG_0]] : tensor<?x?x?xf32>)
+// CHECK:     outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
+// CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
+// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
+// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
+// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
+// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
+// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
+// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
+// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
+// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK:     %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
+// CHECK:     %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
+// CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32
+// CHECK:     %[[YCOMP:.*]] = arith.divf %[[VAR_24]], %[[VAR_9]] : f32
+// CHECK:     %[[VAR_27:.*]] = arith.addf %[[XCOMP]], %[[YCOMP]] : f32
+// CHECK:     %[[ALPHA:.*]] = arith.mulf %[[CST_9]], %[[VAR_27]] : f32
+// CHECK:     %[[COS_ALPHA:.*]] = math.cos %[[ALPHA]] : f32
+// CHECK:     %[[SIN_ALPHA:.*]] = math.sin %[[ALPHA]] : f32
+// CHECK:     %[[REAL_CONTRIB:.*]] = arith.mulf %[[IN]], %[[COS_ALPHA]] : f32
+// CHECK:     %[[IMAG_CONTRIB:.*]] = arith.mulf %[[IN]], %[[SIN_ALPHA]] : f32
+// CHECK:     %[[OUT_REAL:.*]] = arith.addf %[[OUT_0]], %[[REAL_CONTRIB]] : f32
+// CHECK:     %[[OUT_IMAG:.*]] = arith.subf %[[OUT_1]], %[[IMAG_CONTRIB]] : f32
+// CHECK:     linalg.yield %[[OUT_REAL]], %[[OUT_IMAG]] : f32, f32
+// CHECK:   } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+
+  %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}


        


More information about the Mlir-commits mailing list