[Mlir-commits] [mlir] [TOSA] Usage of 32bit integer for 'index to float' in rfft2d (PR #75098)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 13:19:47 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Dmitriy Smirnov (d-smirnov)

<details>
<summary>Changes</summary>

Lowering of rfft2d to linalg now casts index values to i32 when the output float type is 32-bit, and to i64 otherwise.
@<!-- -->eric-k256 @<!-- -->GeorgeARM 

---
Full diff: https://github.com/llvm/llvm-project/pull/75098.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-2) 
- (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+24-24) 


``````````diff
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96f..f0e4a00bdc5179 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2284,8 +2284,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 
   static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                 FloatType type, Value value) {
-    auto integerVal =
-        builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+    auto integerVal = builder.create<arith::IndexCastUIOp>(
+        loc, // i64 only for floats wider than 32 bits (e.g. f64), else i32.
+        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
+                                          : builder.getI32Type(),
+        value);
 
     return builder.create<arith::UIToFPOp>(loc, type, integerVal);
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aa53b366f6da68..6f9597dd399af6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1691,10 +1691,10 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:   %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
 // CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
 // CHECK:   %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
-// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
-// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
-// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
+// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
+// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
+// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1702,17 +1702,17 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:     outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
 // CHECK:     %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
 // CHECK:     %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1761,10 +1761,10 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
 // CHECK:   %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
 // CHECK:   %[[CST_9:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
-// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
-// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
-// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
+// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
+// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
+// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1772,17 +1772,17 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:     outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
 // CHECK:     %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
 // CHECK:     %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/75098


More information about the Mlir-commits mailing list