[Mlir-commits] [mlir] [TOSA] Usage of 32bit integer for 'index to float' in rfft2d (PR #75098)

Dmitriy Smirnov llvmlistbot at llvm.org
Mon Dec 11 13:19:20 PST 2023


https://github.com/d-smirnov created https://github.com/llvm/llvm-project/pull/75098

The lowering of rfft2d to linalg now casts the index to i32 when the output float type is 32-bit, and to i64 otherwise.
@eric-k256 @GeorgeARM 

>From 9d21ac711144864c892927f466b7289e991c3365 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Mon, 11 Dec 2023 12:04:43 +0000
Subject: [PATCH] [TOSA] Usage of 32bit integer for 'index to float' in rfft2d

The lowering of rfft2d to linalg now casts the index to i32
if the output float type is 32-bit, and to i64 otherwise.
---
 .../Conversion/TosaToLinalg/TosaToLinalg.cpp  |  7 ++-
 .../TosaToLinalg/tosa-to-linalg.mlir          | 48 +++++++++----------
 2 files changed, 29 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 3bf7bf12b5e96f..f0e4a00bdc5179 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2284,8 +2284,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 
   static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                 FloatType type, Value value) {
-    auto integerVal =
-        builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+    auto integerVal = builder.create<arith::IndexCastUIOp>(
+        loc,
+        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
+                                          : builder.getI32Type(),
+        value); // i32 is used unless the float is wider than 32 bits.
 
     return builder.create<arith::UIToFPOp>(loc, type, integerVal);
   }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aa53b366f6da68..6f9597dd399af6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1691,10 +1691,10 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:   %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
 // CHECK:   %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
 // CHECK:   %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
-// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
-// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
-// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK:   %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
+// CHECK:   %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i32 to f32
+// CHECK:   %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
+// CHECK:   %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1702,17 +1702,17 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
 // CHECK:     outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK:     %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK:     %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK:     %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK:     %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
 // CHECK:     %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
 // CHECK:     %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1761,10 +1761,10 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:   %[[CST_2:.*]] = arith.constant 2 : index
 // CHECK:   %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
 // CHECK:   %[[CST_9:.*]] = arith.constant 6.28318548 : f32
-// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
-// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
-// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
-// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK:   %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
+// CHECK:   %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
+// CHECK:   %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
+// CHECK:   %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
 // CHECK:   linalg.generic {
 // CHECK:     indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
 // CHECK:     iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1772,17 +1772,17 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
 // CHECK:     outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
 // CHECK:   ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
 // CHECK:     %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK:     %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK:     %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
 // CHECK:     %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK:     %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK:     %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
 // CHECK:     %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK:     %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK:     %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
 // CHECK:     %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK:     %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK:     %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
 // CHECK:     %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
 // CHECK:     %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
 // CHECK:     %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32



More information about the Mlir-commits mailing list