[Mlir-commits] [mlir] [TOSA] FFT2D operator (PR #77005)

Dmitriy Smirnov llvmlistbot at llvm.org
Wed Jan 10 14:37:38 PST 2024


https://github.com/d-smirnov updated https://github.com/llvm/llvm-project/pull/77005

>From b3569dc0864d8ca9a904843945b02a93c63e7cbd Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Thu, 14 Dec 2023 11:26:53 +0000
Subject: [PATCH 1/2] [TOSA] FFT2D operator

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 129 +++++++++++++++++
 .../TosaToLinalg/tosa-to-linalg.mlir          | 134 ++++++++++++++++++
 2 files changed, 263 insertions(+)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 678081837b8138..f0f01dfad79240 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2344,6 +2344,134 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
   }
 };
 
+struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
+  using OpRewritePattern<FFT2dOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FFT2dOp fft2d,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::all_of(fft2d->getOperandTypes(),
+                      RFFT2dConverter::isRankedTensor) ||
+        !llvm::all_of(fft2d->getResultTypes(),
+                      RFFT2dConverter::isRankedTensor)) {
+      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
+    }
+
+    auto loc = fft2d.getLoc();
+    auto input_real = fft2d.getInputReal();
+    auto input_imag = fft2d.getInputImag();
+    auto inverse = fft2d.getInverseAttr();
+
+    auto real_el_ty = input_real.getType()
+                          .cast<ShapedType>()
+                          .getElementType()
+                          .cast<FloatType>();
+    auto imag_el_ty = input_imag.getType()
+                          .cast<ShapedType>()
+                          .getElementType()
+                          .cast<FloatType>();
+
+    assert(real_el_ty == imag_el_ty);
+
+    // Compute the output type and set of dynamic sizes
+    llvm::SmallVector<Value> dynamicSizes;
+
+    // Get [N, H, W]
+    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
+
+    llvm::SmallVector<int64_t, 3> staticSizes;
+    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
+
+    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
+
+    // Iterator types for the linalg.generic implementation
+    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel,
+        utils::IteratorType::parallel, utils::IteratorType::reduction,
+        utils::IteratorType::reduction};
+
+    // Inputs/outputs to the linalg.generic implementation
+    llvm::SmallVector<Value> genericOpInputs = {input_real, input_imag};
+    llvm::SmallVector<Value> genericOpOutputs = {
+        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
+                                          dynamicSizes),
+        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
+                                          dynamicSizes)};
+
+    // Indexing maps for input and output tensors
+    auto indexingMaps = AffineMap::inferFromExprList(
+        llvm::ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
+                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
+                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
+                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)});
+
+    // Width and height dimensions of the original input.
+    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
+    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);
+
+    // Constants and dimension sizes
+    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
+    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
+    auto constH =
+        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
+    auto constW =
+        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
+
+    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
+      Value valReal = args[0];
+      Value valImag = args[1];
+      Value sumReal = args[2];
+      Value sumImag = args[3];
+
+      // Indices for angle computation
+      auto oy = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
+      auto ox = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
+      auto iy = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
+      auto ix = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
+
+      // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W);
+      auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
+      auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
+      auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
+      auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
+      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+      if (inverse.getValue()) {
+        angle = builder.create<arith::MulFOp>(
+            loc, angle,
+            builder.create<arith::ConstantOp>(
+                loc, builder.getFloatAttr(real_el_ty, -1.0)));
+      }
+
+      // realComponent = val_real * cos(a) + val_imag * sin(a);
+      // imagComponent = -val_real * sin(a) + val_imag * cos(a);
+      auto cosAngle = builder.create<math::CosOp>(loc, angle);
+      auto sinAngle = builder.create<math::SinOp>(loc, angle);
+
+      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
+      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
+      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);
+
+      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
+      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
+
+      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);
+
+      // outReal = sumReal + realComponent
+      // outImag = sumImag + imagComponent
+      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
+      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
+
+      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
+    };
+
+    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
+        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
+        indexingMaps, iteratorTypes, buildBody);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgConversionPatterns(
@@ -2407,6 +2535,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       RescaleConverter,
       ReverseConverter,
       RFFT2dConverter,
+      FFT2dConverter,
       TableConverter,
       TileConverter>(patterns->getContext());
   // clang-format on
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3931e454da2e22..09b86430bcb2c4 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1739,3 +1739,137 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
   %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
   return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
 }
+
+// -----
+// NOTE: Assertions generated by utils/generate-test-checks.py, then hand-edited.
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @test_static_fft2d(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<8x8x8xf32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
+// CHECK:           %[[VAL_2:.*]] = tensor.empty() : tensor<8x8x8xf32>
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_2]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
+// CHECK:           %[[VAL_5:.*]] = tensor.empty() : tensor<8x8x8xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_5]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32>
+// CHECK:           %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_10:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_11:.*]] = arith.constant 8 : index
+// CHECK:           %[[VAL_12:.*]] = arith.constant 6.28318548 : f32
+// CHECK:           %[[VAL_13:.*]] = arith.index_castui %[[VAL_9]] : index to i32
+// CHECK:           %[[VAL_14:.*]] = arith.uitofp %[[VAL_13]] : i32 to f32
+// CHECK:           %[[VAL_15:.*]] = arith.index_castui %[[VAL_11]] : index to i32
+// CHECK:           %[[VAL_16:.*]] = arith.uitofp %[[VAL_15]] : i32 to f32
+// CHECK:           %[[VAL_17:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) outs(%[[VAL_4]], %[[VAL_7]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
+// CHECK:           ^bb0(%[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32):
+// CHECK:             %[[VAL_22:.*]] = linalg.index 1 : index
+// CHECK:             %[[VAL_23:.*]] = arith.index_castui %[[VAL_22]] : index to i32
+// CHECK:             %[[VAL_24:.*]] = arith.uitofp %[[VAL_23]] : i32 to f32
+// CHECK:             %[[VAL_25:.*]] = linalg.index 2 : index
+// CHECK:             %[[VAL_26:.*]] = arith.index_castui %[[VAL_25]] : index to i32
+// CHECK:             %[[VAL_27:.*]] = arith.uitofp %[[VAL_26]] : i32 to f32
+// CHECK:             %[[VAL_28:.*]] = linalg.index 3 : index
+// CHECK:             %[[VAL_29:.*]] = arith.index_castui %[[VAL_28]] : index to i32
+// CHECK:             %[[VAL_30:.*]] = arith.uitofp %[[VAL_29]] : i32 to f32
+// CHECK:             %[[VAL_31:.*]] = linalg.index 4 : index
+// CHECK:             %[[VAL_32:.*]] = arith.index_castui %[[VAL_31]] : index to i32
+// CHECK:             %[[VAL_33:.*]] = arith.uitofp %[[VAL_32]] : i32 to f32
+// CHECK:             %[[VAL_34:.*]] = arith.mulf %[[VAL_30]], %[[VAL_24]] : f32
+// CHECK:             %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_27]] : f32
+// CHECK:             %[[VAL_36:.*]] = arith.divf %[[VAL_34]], %[[VAL_14]] : f32
+// CHECK:             %[[VAL_37:.*]] = arith.divf %[[VAL_35]], %[[VAL_16]] : f32
+// CHECK:             %[[VAL_38:.*]] = arith.addf %[[VAL_36]], %[[VAL_37]] : f32
+// CHECK:             %[[VAL_39:.*]] = arith.mulf %[[VAL_12]], %[[VAL_38]] : f32
+// CHECK:             %[[VAL_40:.*]] = math.cos %[[VAL_39]] : f32
+// CHECK:             %[[VAL_41:.*]] = math.sin %[[VAL_39]] : f32
+// CHECK:             %[[VAL_42:.*]] = arith.mulf %[[VAL_18]], %[[VAL_40]] : f32
+// CHECK:             %[[VAL_43:.*]] = arith.mulf %[[VAL_19]], %[[VAL_41]] : f32
+// CHECK:             %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
+// CHECK:             %[[VAL_45:.*]] = arith.mulf %[[VAL_19]], %[[VAL_40]] : f32
+// CHECK:             %[[VAL_46:.*]] = arith.mulf %[[VAL_18]], %[[VAL_41]] : f32
+// CHECK:             %[[VAL_47:.*]] = arith.subf %[[VAL_45]], %[[VAL_46]] : f32
+// CHECK:             %[[VAL_48:.*]] = arith.addf %[[VAL_20]], %[[VAL_44]] : f32
+// CHECK:             %[[VAL_49:.*]] = arith.addf %[[VAL_21]], %[[VAL_47]] : f32
+// CHECK:             linalg.yield %[[VAL_48]], %[[VAL_49]] : f32, f32
+// CHECK:           } -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>)
+// CHECK:           return %[[VAL_17]]#0, %[[VAL_17]]#1 : tensor<8x8x8xf32>, tensor<8x8x8xf32>
+// CHECK:         }
+func.func @test_static_fft2d(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) {
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse=false} : (tensor<8x8x8xf32>, tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>)
+  return %output_real, %output_imag : tensor<8x8x8xf32>, tensor<8x8x8xf32>
+}
+
+// -----
+// NOTE: Assertions generated by utils/generate-test-checks.py, then hand-edited.
+
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
+
+// CHECK-LABEL:   func.func @test_dynamic_fft2d(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                                  %[[VAL_1:.*]]: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_8:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_8]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:           %[[VAL_11:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_15:.*]] = tensor.dim %[[VAL_0]], %[[VAL_14]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_18:.*]] = arith.constant 6.28318548 : f32
+// CHECK:           %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
+// CHECK:           %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
+// CHECK:           %[[VAL_21:.*]] = arith.index_castui %[[VAL_17]] : index to i32
+// CHECK:           %[[VAL_22:.*]] = arith.uitofp %[[VAL_21]] : i32 to f32
+// CHECK:           %[[VAL_23:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_3]], #[[$ATTR_3]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[VAL_10]], %[[VAL_13]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK:           ^bb0(%[[VAL_24:.*]]: f32, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32):
+// CHECK:             %[[VAL_28:.*]] = linalg.index 1 : index
+// CHECK:             %[[VAL_29:.*]] = arith.index_castui %[[VAL_28]] : index to i32
+// CHECK:             %[[VAL_30:.*]] = arith.uitofp %[[VAL_29]] : i32 to f32
+// CHECK:             %[[VAL_31:.*]] = linalg.index 2 : index
+// CHECK:             %[[VAL_32:.*]] = arith.index_castui %[[VAL_31]] : index to i32
+// CHECK:             %[[VAL_33:.*]] = arith.uitofp %[[VAL_32]] : i32 to f32
+// CHECK:             %[[VAL_34:.*]] = linalg.index 3 : index
+// CHECK:             %[[VAL_35:.*]] = arith.index_castui %[[VAL_34]] : index to i32
+// CHECK:             %[[VAL_36:.*]] = arith.uitofp %[[VAL_35]] : i32 to f32
+// CHECK:             %[[VAL_37:.*]] = linalg.index 4 : index
+// CHECK:             %[[VAL_38:.*]] = arith.index_castui %[[VAL_37]] : index to i32
+// CHECK:             %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32
+// CHECK:             %[[VAL_40:.*]] = arith.mulf %[[VAL_36]], %[[VAL_30]] : f32
+// CHECK:             %[[VAL_41:.*]] = arith.mulf %[[VAL_39]], %[[VAL_33]] : f32
+// CHECK:             %[[VAL_42:.*]] = arith.divf %[[VAL_40]], %[[VAL_20]] : f32
+// CHECK:             %[[VAL_43:.*]] = arith.divf %[[VAL_41]], %[[VAL_22]] : f32
+// CHECK:             %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
+// CHECK:             %[[VAL_45:.*]] = arith.mulf %[[VAL_18]], %[[VAL_44]] : f32
+// CHECK:             %[[VAL_46:.*]] = arith.constant -1.000000e+00 : f32
+// CHECK:             %[[VAL_47:.*]] = arith.mulf %[[VAL_45]], %[[VAL_46]] : f32
+// CHECK:             %[[VAL_48:.*]] = math.cos %[[VAL_47]] : f32
+// CHECK:             %[[VAL_49:.*]] = math.sin %[[VAL_47]] : f32
+// CHECK:             %[[VAL_50:.*]] = arith.mulf %[[VAL_24]], %[[VAL_48]] : f32
+// CHECK:             %[[VAL_51:.*]] = arith.mulf %[[VAL_25]], %[[VAL_49]] : f32
+// CHECK:             %[[VAL_52:.*]] = arith.addf %[[VAL_50]], %[[VAL_51]] : f32
+// CHECK:             %[[VAL_53:.*]] = arith.mulf %[[VAL_25]], %[[VAL_48]] : f32
+// CHECK:             %[[VAL_54:.*]] = arith.mulf %[[VAL_24]], %[[VAL_49]] : f32
+// CHECK:             %[[VAL_55:.*]] = arith.subf %[[VAL_53]], %[[VAL_54]] : f32
+// CHECK:             %[[VAL_56:.*]] = arith.addf %[[VAL_26]], %[[VAL_52]] : f32
+// CHECK:             %[[VAL_57:.*]] = arith.addf %[[VAL_27]], %[[VAL_55]] : f32
+// CHECK:             linalg.yield %[[VAL_56]], %[[VAL_57]] : f32, f32
+// CHECK:           } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK:           return %[[VAL_23]]#0, %[[VAL_23]]#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+// CHECK:         }
+func.func @test_dynamic_fft2d(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %output_real, %output_imag : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}

>From d2523431a49fc8ff58373ec974f6c88a72ac5a7b Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Wed, 10 Jan 2024 22:36:50 +0000
Subject: [PATCH 2/2] Addressed comments

---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  | 61 ++++++++++---------
 1 file changed, 31 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index f0f01dfad79240..1e94dfd7feb94e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2344,8 +2344,8 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
   }
 };
 
-struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
-  using OpRewritePattern<FFT2dOp>::OpRewritePattern;
+struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                 PatternRewriter &rewriter) const override {
@@ -2356,42 +2356,39 @@ struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
       return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
     }
 
-    auto loc = fft2d.getLoc();
-    auto input_real = fft2d.getInputReal();
-    auto input_imag = fft2d.getInputImag();
-    auto inverse = fft2d.getInverseAttr();
+    Location loc = fft2d.getLoc();
+    Value input_real = fft2d.getInputReal();
+    Value input_imag = fft2d.getInputImag();
+    BoolAttr inverse = fft2d.getInverseAttr();
 
-    auto real_el_ty = input_real.getType()
-                          .cast<ShapedType>()
-                          .getElementType()
-                          .cast<FloatType>();
-    auto imag_el_ty = input_imag.getType()
-                          .cast<ShapedType>()
-                          .getElementType()
-                          .cast<FloatType>();
+    auto real_el_ty = cast<FloatType>(
+        cast<ShapedType>(input_real.getType()).getElementType());
+    auto imag_el_ty = cast<FloatType>(
+        cast<ShapedType>(input_imag.getType()).getElementType());
 
     assert(real_el_ty == imag_el_ty);
 
     // Compute the output type and set of dynamic sizes
-    llvm::SmallVector<Value> dynamicSizes;
+    SmallVector<Value> dynamicSizes;
 
     // Get [N, H, W]
-    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
+    SmallVector<OpFoldResult> dims =
+        tensor::getMixedSizes(rewriter, loc, input_real);
 
-    llvm::SmallVector<int64_t, 3> staticSizes;
+    SmallVector<int64_t, 3> staticSizes;
     dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
 
     auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
 
     // Iterator types for the linalg.generic implementation
-    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
+    SmallVector<utils::IteratorType, 5> iteratorTypes = {
         utils::IteratorType::parallel, utils::IteratorType::parallel,
         utils::IteratorType::parallel, utils::IteratorType::reduction,
         utils::IteratorType::reduction};
 
     // Inputs/outputs to the linalg.generic implementation
-    llvm::SmallVector<Value> genericOpInputs = {input_real, input_imag};
-    llvm::SmallVector<Value> genericOpOutputs = {
+    SmallVector<Value> genericOpInputs = {input_real, input_imag};
+    SmallVector<Value> genericOpOutputs = {
         RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                           dynamicSizes),
         RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
@@ -2399,10 +2396,10 @@ struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
 
     // Indexing maps for input and output tensors
     auto indexingMaps = AffineMap::inferFromExprList(
-        llvm::ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
-                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
-                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
-                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)});
+        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
+                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
+                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
+                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)});
 
     // Width and height dimensions of the original input.
     auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
@@ -2411,9 +2408,9 @@ struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
     // Constants and dimension sizes
     auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
     auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
-    auto constH =
+    Value constH =
         RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
-    auto constW =
+    Value constW =
         RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
 
     auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
@@ -2423,10 +2420,14 @@ struct FFT2dConverter final : public OpRewritePattern<FFT2dOp> {
       Value sumImag = args[3];
 
       // Indices for angle computation
-      auto oy = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
-      auto ox = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
-      auto iy = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
-      auto ix = RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
+      Value oy =
+          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
+      Value ox =
+          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
+      Value iy =
+          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
+      Value ix =
+          RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
 
       // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W);
       auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);



More information about the Mlir-commits mailing list