[Mlir-commits] [mlir] [mlir][tosa] Fix integer-to-boolean cast folder (PR #150370)
Longsheng Mou
llvmlistbot at llvm.org
Thu Jul 24 05:13:22 PDT 2025
https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/150370
>From b7f11d0864f3ee71c3a8348ca8cb9d6f8560fbc3 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 24 Jul 2025 12:24:26 +0800
Subject: [PATCH 1/4] [mlir][tosa] Fix integer-to-boolean cast folder
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
According to the TOSA spec, casting an integer to boolean must produce true if the input is non-zero and false otherwise, i.e. `out = (in != 0) ? true : false`. The previous folder truncated the value to i1 instead. That keeps only the least significant bit, so it yields false for non-zero inputs whose least significant bit is zero (e.g. 10 folded to false instead of true).
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 606626dfe4d2c..080955bf94761 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1302,9 +1302,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto intVal = operand.getSplatValue<APInt>();
auto bitwidth = outETy.getIntOrFloatBitWidth();
+ // i1 types are boolean in TOSA
if (trunc) {
- intVal = intVal.trunc(bitwidth);
- // i1 types are boolean in TOSA
+ if (outETy.isInteger(1)) {
+ intVal = intVal.isZero() ? APInt(bitwidth, 0) : APInt(bitwidth, 1);
+ } else {
+ intVal = intVal.trunc(bitwidth);
+ }
} else if (unsignIn || inIntType.isInteger(1)) {
intVal = intVal.zext(bitwidth);
} else {
>From 304fa75706cccf208fd0b0cab249090cb7688900 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 24 Jul 2025 12:28:57 +0800
Subject: [PATCH 2/4] add test
---
mlir/test/Dialect/Tosa/canonicalize.mlir | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 11c8d54fda055..6b55442a82a0a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1349,3 +1349,14 @@ func.func @test_fold_i1_to_i32_cast() -> tensor<i32> {
%1 = "tosa.cast"(%0) : (tensor<i1>) -> tensor<i32>
return %1 : tensor<i32>
}
+
+// -----
+
+// CHECK-LABEL: @test_fold_i32_to_i1_cast
+// CHECK: %[[OUT:.*]] = "tosa.const"() <{values = dense<true> : tensor<i1>}> : () -> tensor<i1>
+// CHECK: return %[[OUT]] : tensor<i1>
+func.func @test_fold_i32_to_i1_cast() -> tensor<i1> {
+ %0 = "tosa.const"() <{values = dense<10> : tensor<i32>}> : () -> tensor<i32>
+ %1 = "tosa.cast"(%0) : (tensor<i32>) -> tensor<i1>
+ return %1 : tensor<i1>
+}
>From e7a8c0cca5354071a1ea04c596d3598a11213945 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 24 Jul 2025 14:13:10 +0800
Subject: [PATCH 3/4] simplify code
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 080955bf94761..dd99bb8f8e5c3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1305,7 +1305,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
// i1 types are boolean in TOSA
if (trunc) {
if (outETy.isInteger(1)) {
- intVal = intVal.isZero() ? APInt(bitwidth, 0) : APInt(bitwidth, 1);
+ intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
} else {
intVal = intVal.trunc(bitwidth);
}
>From 05a84d151c3a8615fa280705fc21264610c7a41d Mon Sep 17 00:00:00 2001
From: Longsheng Mou <longshengmou at gmail.com>
Date: Thu, 24 Jul 2025 20:13:13 +0800
Subject: [PATCH 4/4] reduce nesting
---
mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index dd99bb8f8e5c3..34e7e4200cd44 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1303,12 +1303,10 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
auto bitwidth = outETy.getIntOrFloatBitWidth();
// i1 types are boolean in TOSA
- if (trunc) {
- if (outETy.isInteger(1)) {
- intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
- } else {
- intVal = intVal.trunc(bitwidth);
- }
+ if (outETy.isInteger(1)) {
+ intVal = APInt(bitwidth, intVal.isZero() ? 0 : 1);
+ } else if (trunc) {
+ intVal = intVal.trunc(bitwidth);
} else if (unsignIn || inIntType.isInteger(1)) {
intVal = intVal.zext(bitwidth);
} else {
More information about the Mlir-commits mailing list