[Mlir-commits] [mlir] [TOSA] Usage of 32bit integer for 'index to float' in rfft2d (PR #75098)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 11 13:19:47 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Dmitriy Smirnov (d-smirnov)
<details>
<summary>Changes</summary>
The lowering of rfft2d to linalg now casts the index to i32 when the output float type is 32-bit, and to i64 otherwise.
@<!-- -->eric-k256 @<!-- -->GeorgeARM
---
Full diff: https://github.com/llvm/llvm-project/pull/75098.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-2)
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+24-24)
``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96f..f0e4a00bdc5179 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2284,8 +2284,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
static Value castIndexToFloat(OpBuilder &builder, Location loc,
FloatType type, Value value) {
- auto integerVal =
- builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+ auto integerVal = builder.create<arith::IndexCastUIOp>(
+ loc,
+ 32 <= type.getIntOrFloatBitWidth() ? builder.getI32Type()
+ : builder.getI64Type(),
+ value);
return builder.create<arith::UIToFPOp>(loc, type, integerVal);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aa53b366f6da68..6f9597dd399af6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1691,10 +1691,10 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
// CHECK: %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
-// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
-// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
-// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
+// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
+// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
+// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
// CHECK: linalg.generic {
// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1702,17 +1702,17 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
// CHECK: outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
// CHECK: %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
// CHECK: %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1761,10 +1761,10 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
// CHECK: %[[CST_9:.*]] = arith.constant 6.28318548 : f32
-// CHECK: %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
-// CHECK: %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
-// CHECK: %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
-// CHECK: %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK: %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
+// CHECK: %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
+// CHECK: %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
+// CHECK: %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
// CHECK: linalg.generic {
// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1772,17 +1772,17 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
// CHECK: outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK: %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK: %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK: %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK: %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK: %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK: %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK: %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK: %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK: %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK: %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK: %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK: %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK: %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK: %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK: %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK: %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
// CHECK: %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
// CHECK: %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32
``````````
</details>
https://github.com/llvm/llvm-project/pull/75098
More information about the Mlir-commits mailing list