[Mlir-commits] [mlir] [TOSA] FFT2D/RFFT2D accuracy increased (PR #88510)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 12 06:03:56 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Dmitriy Smirnov (d-smirnov)
<details>
<summary>Changes</summary>
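This PR increases the accuracy of the FFT2D/RFFT2D calculation by removing the periodic part of the sin/cos arguments: the index products (iy * oy) and (ix * ox) are now reduced modulo H and W in integer (index) arithmetic before being converted to floating point. Previously the unreduced products were used directly.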
---
Patch is 35.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88510.diff
3 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+41-24)
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp (+2-1)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+168-179)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412be..d63218f4a0420c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -2364,16 +2365,24 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
Value sumImag = args[2];
// Indices for angle computation
- auto oy = createLinalgIndex(builder, loc, elementType, 1);
- auto ox = createLinalgIndex(builder, loc, elementType, 2);
- auto iy = createLinalgIndex(builder, loc, elementType, 3);
- auto ix = createLinalgIndex(builder, loc, elementType, 4);
-
- // angle = 2 * pi() * ((iy * oy) / H + (ix * ox) / W)
- auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
- auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+ Value oy = builder.create<linalg::IndexOp>(loc, 1);
+ Value ox = builder.create<linalg::IndexOp>(loc, 2);
+ Value iy = builder.create<linalg::IndexOp>(loc, 3);
+ Value ix = builder.create<linalg::IndexOp>(loc, 4);
+
+ // float_t angle = 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix * ox) % W ) /
+ // W);
+ auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
+ auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+
+ auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
+ auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+
+ auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
+ auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
+
+ auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
+ auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
@@ -2478,22 +2487,30 @@ struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
Value sumImag = args[3];
// Indices for angle computation
- Value oy =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1);
- Value ox =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2);
- Value iy =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3);
- Value ix =
- RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4);
-
- // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W);
- auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
- auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
- auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
- auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
+ Value oy = builder.create<linalg::IndexOp>(loc, 1);
+ Value ox = builder.create<linalg::IndexOp>(loc, 2);
+ Value iy = builder.create<linalg::IndexOp>(loc, 3);
+ Value ix = builder.create<linalg::IndexOp>(loc, 4);
+
+ // float_t angle = sign_val * 2 * pi() * ( ( (iy * oy) % H) / H + ( (ix *
+ // ox) % W ) / W);
+ auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
+ auto ixXox = builder.create<index::MulOp>(loc, ix, ox);
+
+ auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
+ auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);
+
+ auto iyRemFloat =
+ RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
+ auto ixRemFloat =
+ RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
+
+ auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
+ auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
+
auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);
+
if (inverse.getValue()) {
angle = builder.create<arith::MulFOp>(
loc, angle,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index 687477810030d4..ad7f6cf84e5edc 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -40,7 +41,7 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
- tensor::TensorDialect, scf::SCFDialect>();
+ index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
}
void runOnOperation() override {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 1fa783f05f04ee..9e6112de20932f 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1622,138 +1622,132 @@ func.func @table8_dyn_table(%arg0: tensor<6xi8>, %arg1: tensor<?xi8>) -> () {
}
// -----
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_static_rfft2d
-// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+// CHECK-LABEL: func.func @test_static_rfft2d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_7]] : f32) outs(%[[VAL_6]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<5x5x5xf32>
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : f32) outs(%[[VAL_9]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_13:.*]] = arith.constant 5 : index
+// CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_15:.*]] = arith.constant 8 : index
+// CHECK: %[[VAL_16:.*]] = arith.constant 6.28318548 : f32
+// CHECK: %[[VAL_17:.*]] = arith.index_castui %[[VAL_13]] : index to i32
+// CHECK: %[[VAL_18:.*]] = arith.uitofp %[[VAL_17]] : i32 to f32
+// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32
+// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32
+// CHECK: %[[VAL_21:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<5x5x8xf32>) outs(%[[VAL_8]], %[[VAL_11]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
+// CHECK: ^bb0(%[[VAL_22:.*]]: f32, %[[VAL_23:.*]]: f32, %[[VAL_24:.*]]: f32):
+// CHECK: %[[VAL_25:.*]] = linalg.index 1 : index
+// CHECK: %[[VAL_26:.*]] = linalg.index 2 : index
+// CHECK: %[[VAL_27:.*]] = linalg.index 3 : index
+// CHECK: %[[VAL_28:.*]] = linalg.index 4 : index
+// CHECK: %[[VAL_29:.*]] = index.mul %[[VAL_27]], %[[VAL_25]]
+// CHECK: %[[VAL_30:.*]] = index.mul %[[VAL_28]], %[[VAL_26]]
+// CHECK: %[[VAL_31:.*]] = index.remu %[[VAL_29]], %[[VAL_13]]
+// CHECK: %[[VAL_32:.*]] = index.remu %[[VAL_30]], %[[VAL_15]]
+// CHECK: %[[VAL_33:.*]] = arith.index_castui %[[VAL_31]] : index to i32
+// CHECK: %[[VAL_34:.*]] = arith.uitofp %[[VAL_33]] : i32 to f32
+// CHECK: %[[VAL_35:.*]] = arith.index_castui %[[VAL_32]] : index to i32
+// CHECK: %[[VAL_36:.*]] = arith.uitofp %[[VAL_35]] : i32 to f32
+// CHECK: %[[VAL_37:.*]] = arith.divf %[[VAL_34]], %[[VAL_18]] : f32
+// CHECK: %[[VAL_38:.*]] = arith.divf %[[VAL_36]], %[[VAL_20]] : f32
+// CHECK: %[[VAL_39:.*]] = arith.addf %[[VAL_37]], %[[VAL_38]] : f32
+// CHECK: %[[VAL_40:.*]] = arith.mulf %[[VAL_16]], %[[VAL_39]] : f32
+// CHECK: %[[VAL_41:.*]] = math.cos %[[VAL_40]] : f32
+// CHECK: %[[VAL_42:.*]] = math.sin %[[VAL_40]] : f32
+// CHECK: %[[VAL_43:.*]] = arith.mulf %[[VAL_22]], %[[VAL_41]] : f32
+// CHECK: %[[VAL_44:.*]] = arith.mulf %[[VAL_22]], %[[VAL_42]] : f32
+// CHECK: %[[VAL_45:.*]] = arith.addf %[[VAL_23]], %[[VAL_43]] : f32
+// CHECK: %[[VAL_46:.*]] = arith.subf %[[VAL_24]], %[[VAL_44]] : f32
+// CHECK: linalg.yield %[[VAL_45]], %[[VAL_46]] : f32, f32
+// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
+// CHECK: return %[[VAL_47:.*]]#0, %[[VAL_47]]#1 : tensor<5x5x5xf32>, tensor<5x5x5xf32>
+// CHECK: }
func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
-// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
-// CHECK: %[[CST_8:.*]] = arith.constant 8 : index
-// CHECK: %[[CST_4:.*]] = arith.constant 4 : index
-// CHECK: %[[CST_5:.*]] = arith.constant 5 : index
-// CHECK: %[[EMPTY_0:.*]] = tensor.empty() : tensor<5x5x5xf32>
-// CHECK: %[[CST_ZERO:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR_1:.*]] = linalg.fill ins(%[[CST_ZERO:.*]] : f32) outs(%[[EMPTY_0:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
-// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
-// CHECK: %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
-// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
-// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
-// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
-// CHECK: linalg.generic {
-// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
-// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
-// CHECK: ins(%[[ARG_0]] : tensor<5x5x8xf32>)
-// CHECK: outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
-// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
-// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
-// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
-// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
-// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
-// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
-// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
-// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
-// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
-// CHECK: %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
-// CHECK: %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
-// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
-// CHECK: %[[YCOMP:.*]] = arith.divf %[[VAR_22]], %[[VAR_8]] : f32
-// CHECK: %[[VAR_25:.*]] = arith.addf %[[XCOMP]], %[[YCOMP]] : f32
-// CHECK: %[[ALPHA:.*]] = arith.mulf %[[CST_PI]], %[[VAR_25]] : f32
-// CHECK: %[[COS_ALPHA:.*]] = math.cos %[[ALPHA]] : f32
-// CHECK: %[[SIN_ALPHA:.*]] = math.sin %[[ALPHA]] : f32
-// CHECK: %[[REAL_CONTRIB:.*]] = arith.mulf %[[IN]], %[[COS_ALPHA]] : f32
-// CHECK: %[[IMAG_CONTRIB:.*]] = arith.mulf %[[IN]], %[[SIN_ALPHA]] : f32
-// CHECK: %[[OUT_REAL:.*]] = arith.addf %[[OUT_0]], %[[REAL_CONTRIB]] : f32
-// CHECK: %[[OUT_IMAG:.*]] = arith.subf %[[OUT_1]], %[[IMAG_CONTRIB]] : f32
-// CHECK: linalg.yield %[[OUT_REAL]], %[[OUT_IMAG]] : f32, f32
-// CHECK: } -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
-
%output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, tensor<5x5x5xf32>)
return %output_real, %output_imag : tensor<5x5x5xf32>, tensor<5x5x5xf32>
}
// -----
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
-
-// CHECK-LABEL: @test_dynamic_rfft2d
-// CHECK-SAME: (%[[ARG_0:[0-9a-zA-Z_]*]]:
+// CHECK-LABEL: func.func @test_dynamic_rfft2d(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_4:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_5:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_9:.*]] = arith.divui %[[VAL_6]], %[[VAL_8]] : index
+// CHECK: %[[VAL_10:.*]] = arith.addi %[[VAL_9]], %[[VAL_7]] : index
+// CHECK: %[[VAL_11:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
+// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[VAL_14:.*]] = tensor.empty(%[[VAL_2]], %[[VAL_4]], %[[VAL_10]]) : tensor<?x?x?xf32>
+// CHECK: %[[VAL_15:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_16:.*]] = linalg.fill ins(%[[VAL_15]] : f32) outs(%[[VAL_14]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK: %[[VAL_17:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_18:.*]] = tensor.dim %[[VAL_0]], %[[VAL_17]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_19:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_21:.*]] = arith.constant 6.28318548 : f32
+// CHECK: %[[VAL_22:.*]] = arith.index_castui %[[VAL_18]] : index to i32
+// CHECK: %[[VAL_23:.*]] = arith.uitofp %[[VAL_22]] : i32 to f32
+// CHECK: %[[VAL_24:.*]] = arith.index_castui %[[VAL_20]] : index to i32
+// CHECK: %[[VAL_25:.*]] = arith.uitofp %[[VAL_24]] : i32 to f32
+// CHECK: %[[VAL_26:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]] : tensor<?x?x?xf32>) outs(%[[VAL_13]], %[[VAL_16]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+// CHECK: ^bb0(%[[VAL_27:.*]]: f32, %[[VAL_28:.*]]: f32, %[[VAL_29:.*]]: f32):
+// CHECK: %[[VAL_30:.*]] = linalg.index 1 : index
+// CHECK: %[[VAL_31:.*]] = linalg.index 2 : index
+// CHECK: %[[VAL_32:.*]] = linalg.index 3 : index
+// CHECK: %[[VAL_33:.*]] = linalg.index 4 : index
+// CHECK: %[[VAL_34:.*]] = index.mul %[[VAL_32]], %[[VAL_30]]
+// CHECK: %[[VAL_35:.*]] = index.mul %[[VAL_33]], %[[VAL_31]]
+// CHECK: %[[VAL_36:.*]] = index.remu %[[VAL_34]], %[[VAL_18]]
+// CHECK: %[[VAL_37:.*]] = index.remu %[[VAL_35]], %[[VAL_20]]
+// CHECK: %[[VAL_38:.*]] = arith.index_castui %[[VAL_36]] : index to i32
+// CHECK: %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32
+// CHECK: %[[VAL_40:.*]] = arith.index_castui %[[VAL_37]] : index to i32
+// CHECK: %[[VAL_41:.*]] = arith.uitofp %[[VAL_40]] : i32 to f32
+// CHECK: %[[VAL_42:.*]] = arith.divf %[[VAL_39]], %[[VAL_23]] : f32
+// CHECK: %[[VAL_43:.*]] = arith.divf %[[VAL_41]], %[[VAL_25]] : f32
+// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32
+// CHECK: %[[VAL_45:.*]] = arith.mulf %[[VAL_21]], %[[VAL_44]] : f32
+// CHECK: %[[VAL_46:.*]] = math.cos %[[VAL_45]] : f32
+// CHECK: %[[VAL_47:.*]] = math.sin %[[VAL_45]] : f32
+// CHECK: %[[VAL_48:.*]] = arith.mulf %[[VAL_27]], %[[VAL_46]] : f32
+// CHECK: %[[VAL_49:.*]] = arith.mulf %[[VAL_27]], %[[VAL_47]] : f32
+// CHECK: %[[VAL_50:.*]] = arith.addf %[[VAL_28]], %[[VAL_48]] : f32
+// CHECK: %[[VAL_51:.*]] = arith.subf %[[VAL_29]], %[[VAL_49]] : f32
+// CHECK: linalg.yield %[[VAL_50]], %[[VAL_51]] : f32, f32
+// CHECK: } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+// CHECK: return %[[VAL_52:.*]]#0, %[[VAL_52]]#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+// CHECK: }
func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
-// CHECK: %[[CST_0:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[CST_0]] : tensor<?x?x?xf32>
-// CHECK: %[[CST_1:.*]] = arith.constant 1 : index
-// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_0]], %[[CST_1]] : tensor<?x?x?xf32>
-// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
-// CHECK: %[[DIM_1:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
-// CHECK: %[[CST_1_2:.*]] = arith.constant 1 : index
-// CHECK: %[[CST_2_3:.*]] = arith.constant 2 : index
-// CHECK: %[[VAR_0:.*]] = arith.divui %[[DIM_1]], %[[CST_2_3]] : index
-// CHECK: %[[VAR_1:.*]] = arith.addi %[[VAR_0]], %[[CST_1_2]] : index
-// CHECK: %[[EMPTY_0:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
-// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_0]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: %[[EMPTY_1:.*]] = tensor.empty(%[[DIM]], %[[DIM_0]], %[[VAR_1]]) : tensor<?x?x?xf32>
-// CHECK: %[[CST_4:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAR_5:.*]] = linalg.fill ins(%[[CST_4]] : f32) outs(%[[EMPTY_1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: %[...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/88510
More information about the Mlir-commits
mailing list