[Mlir-commits] [mlir] [mlir][tosa] Use roundeven in TOSA cast splat constant op folding (PR #99484)
Corentin Ferry
llvmlistbot at llvm.org
Thu Jul 18 05:50:22 PDT 2024
https://github.com/cferry-AMD updated https://github.com/llvm/llvm-project/pull/99484
>From c5b4b37d0e0dcadbc7a2e1635946df806fe58851 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Wed, 19 Jun 2024 10:05:54 +0200
Subject: [PATCH 1/2] Use roundeven in TOSA cast splat constant op folding
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +-
mlir/test/Dialect/Tosa/constant-op-fold.mlir | 11 +++++++++++
2 files changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 866ab0d2228f7..f4d4df5fd0031 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -764,7 +764,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
auto floatVal = operand.getSplatValue<APFloat>();
bool exact;
- floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact);
+ floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven, &exact);
return SplatElementsAttr::get(outTy, intVal);
}
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index ad82c9f8858e6..8e19f87dbf4aa 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -566,6 +566,17 @@ func.func @cast_float_to_int() -> tensor<i16> {
// -----
+// CHECK: func.func @cast_float_to_int_round
+func.func @cast_float_to_int_round() -> tensor<i16> {
+ %splat = "tosa.const"() {value = dense<-3.5> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{value = dense<-4> : tensor<i16>}
+ %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i16>
+ // CHECK: return %[[SPLAT]]
+ return %cast : tensor<i16>
+}
+
+// -----
+
// CHECK: func.func @cast_int_to_int_trunc
func.func @cast_int_to_int_trunc() -> tensor<i16> {
%splat = "tosa.const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
>From eed5a85098cae8739fc7abcb4b3fa688cef276f1 Mon Sep 17 00:00:00 2001
From: Corentin Ferry <corentin.ferry at amd.com>
Date: Thu, 18 Jul 2024 13:50:01 +0100
Subject: [PATCH 2/2] Correct patch formatting
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index f4d4df5fd0031..da9a93feac4d6 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -764,7 +764,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
llvm::cast<IntegerType>(outETy).getIntOrFloatBitWidth(), unsign);
auto floatVal = operand.getSplatValue<APFloat>();
bool exact;
- floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven, &exact);
+ floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven,
+ &exact);
return SplatElementsAttr::get(outTy, intVal);
}
More information about the Mlir-commits mailing list