[Mlir-commits] [mlir] [TOSA] Usage of 32bit integer for 'index to float' in rfft2d (PR #75098)
Dmitriy Smirnov
llvmlistbot at llvm.org
Wed Jan 3 09:38:12 PST 2024
https://github.com/d-smirnov updated https://github.com/llvm/llvm-project/pull/75098
>From f960c10f5ce584afe797fa5b5061915748ba0ee3 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov <dmitriy.smirnov at arm.com>
Date: Mon, 11 Dec 2023 12:04:43 +0000
Subject: [PATCH] [TOSA] Usage of 32bit integer for 'index to float' in rfft2d
Lowering of rfft2d to linalg now casts index values to i32 before the
index-to-float conversion when the output float type is 32 bits or narrower,
and to i64 otherwise.
---
.../Conversion/TosaToLinalg/TosaToLinalg.cpp | 7 ++-
.../TosaToLinalg/tosa-to-linalg.mlir | 48 +++++++++----------
2 files changed, 29 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c35389f1e9290..678081837b8138 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2235,8 +2235,11 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
 static Value castIndexToFloat(OpBuilder &builder, Location loc,
                               FloatType type, Value value) {
-  auto integerVal =
-      builder.create<arith::IndexCastUIOp>(loc, builder.getI64Type(), value);
+  // Go through i32 for f32-or-narrower results, i64 otherwise. index_castui
+  // truncates, so on the i32 path index values >= 2^32 would wrap.
+  Type intType =
+      type.getWidth() > 32 ? builder.getI64Type() : builder.getI32Type();
+  auto integerVal = builder.create<arith::IndexCastUIOp>(loc, intType, value);
   return builder.create<arith::UIToFPOp>(loc, type, integerVal);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 535316e5a37210..3931e454da2e22 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1629,10 +1629,10 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
// CHECK: %[[EMPTY_1:.*]] = tensor.empty() : tensor<5x5x5xf32>
// CHECK: %[[VAR_3:.*]] = linalg.fill ins(%[[CST_ZERO:.*]]: f32) outs(%[[EMPTY_1:.*]] : tensor<5x5x5xf32>) -> tensor<5x5x5xf32>
// CHECK: %[[CST_PI:.*]] = arith.constant 6.28318548 : f32
-// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i64
-// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5:.*]] : i64 to f32
-// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i64
-// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7:.*]] : i64 to f32
+// CHECK: %[[VAR_5:.*]] = arith.index_castui %[[CST_5:.*]] : index to i32
+// CHECK: %[[VAR_6:.*]] = arith.uitofp %[[VAR_5]] : i32 to f32
+// CHECK: %[[VAR_7:.*]] = arith.index_castui %[[CST_8:.*]] : index to i32
+// CHECK: %[[VAR_8:.*]] = arith.uitofp %[[VAR_7]] : i32 to f32
// CHECK: linalg.generic {
// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1640,17 +1640,17 @@ func.func @test_static_rfft2d(%arg0: tensor<5x5x8xf32>) -> (tensor<5x5x5xf32>, t
// CHECK: outs(%[[VAR_1]], %[[VAR_3]] : tensor<5x5x5xf32>, tensor<5x5x5xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i64 to f32
+// CHECK: %[[VAR_10:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK: %[[VAR_11:.*]] = arith.uitofp %[[VAR_10]] : i32 to f32
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i64 to f32
+// CHECK: %[[VAR_13:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK: %[[VAR_14:.*]] = arith.uitofp %[[VAR_13]] : i32 to f32
// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i64 to f32
+// CHECK: %[[VAR_16:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK: %[[VAR_17:.*]] = arith.uitofp %[[VAR_16]] : i32 to f32
// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i64 to f32
+// CHECK: %[[VAR_19:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK: %[[VAR_20:.*]] = arith.uitofp %[[VAR_19]] : i32 to f32
// CHECK: %[[VAR_21:.*]] = arith.mulf %[[VAR_17]], %[[VAR_11]] : f32
// CHECK: %[[VAR_22:.*]] = arith.mulf %[[VAR_20]], %[[VAR_14]] : f32
// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_21]], %[[VAR_6]] : f32
@@ -1699,10 +1699,10 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
// CHECK: %[[CST_2:.*]] = arith.constant 2 : index
// CHECK: %[[DIM_8:.*]] = tensor.dim %[[ARG_0]], %[[CST_2]] : tensor<?x?x?xf32>
// CHECK: %[[CST_9:.*]] = arith.constant 6.28318548 : f32
-// CHECK: %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i64
-// CHECK: %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i64 to f32
-// CHECK: %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i64
-// CHECK: %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i64 to f32
+// CHECK: %[[VAR_6:.*]] = arith.index_castui %[[DIM_6]] : index to i32
+// CHECK: %[[VAR_7:.*]] = arith.uitofp %[[VAR_6]] : i32 to f32
+// CHECK: %[[VAR_8:.*]] = arith.index_castui %[[DIM_8]] : index to i32
+// CHECK: %[[VAR_9:.*]] = arith.uitofp %[[VAR_8]] : i32 to f32
// CHECK: linalg.generic {
// CHECK: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]],
// CHECK: iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]}
@@ -1710,17 +1710,17 @@ func.func @test_dynamic_rfft2d(%arg0: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>,
// CHECK: outs(%[[VAR_3]], %[[VAR_5]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT_0:.*]]: f32, %[[OUT_1:.*]]: f32):
// CHECK: %[[INDEX_1:.*]] = linalg.index 1 : index
-// CHECK: %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i64
-// CHECK: %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i64 to f32
+// CHECK: %[[VAR_12:.*]] = arith.index_castui %[[INDEX_1]] : index to i32
+// CHECK: %[[VAR_13:.*]] = arith.uitofp %[[VAR_12]] : i32 to f32
// CHECK: %[[INDEX_2:.*]] = linalg.index 2 : index
-// CHECK: %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i64
-// CHECK: %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i64 to f32
+// CHECK: %[[VAR_15:.*]] = arith.index_castui %[[INDEX_2]] : index to i32
+// CHECK: %[[VAR_16:.*]] = arith.uitofp %[[VAR_15]] : i32 to f32
// CHECK: %[[INDEX_3:.*]] = linalg.index 3 : index
-// CHECK: %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i64
-// CHECK: %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i64 to f32
+// CHECK: %[[VAR_18:.*]] = arith.index_castui %[[INDEX_3]] : index to i32
+// CHECK: %[[VAR_19:.*]] = arith.uitofp %[[VAR_18]] : i32 to f32
// CHECK: %[[INDEX_4:.*]] = linalg.index 4 : index
-// CHECK: %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i64
-// CHECK: %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i64 to f32
+// CHECK: %[[VAR_21:.*]] = arith.index_castui %[[INDEX_4]] : index to i32
+// CHECK: %[[VAR_22:.*]] = arith.uitofp %[[VAR_21]] : i32 to f32
// CHECK: %[[VAR_23:.*]] = arith.mulf %[[VAR_19]], %[[VAR_13]] : f32
// CHECK: %[[VAR_24:.*]] = arith.mulf %[[VAR_22]], %[[VAR_16]] : f32
// CHECK: %[[XCOMP:.*]] = arith.divf %[[VAR_23]], %[[VAR_7]] : f32
More information about the Mlir-commits
mailing list