[Mlir-commits] [mlir] [mlir][tosa] Improve folder conformance to TOSA specification (PR #200223)
Luke Hutton
llvmlistbot at llvm.org
Thu May 28 09:50:21 PDT 2026
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/200223
This commit fixes some bugs in TOSA folders that cause non-conformant results. The fixes include:
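- tosa.intdiv - Previously folded to zero whenever the lhs was a zero splat, even if the rhs was zero or unknown. In the TOSA specification 0 / 0 is undefined behaviour, so the fold now requires the rhs to be a constant non-zero splat.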
- tosa.div_ceil_shape/tosa.div_floor_shape - Folding when the lhs is negative or the rhs is non-positive. In the TOSA specification this is undefined behaviour.
In addition, some test cases have been added for non-exercised code paths, including:
- tosa.intdiv - Rejects overflow cases
- tosa.greater/tosa.greater_equal/tosa.equal - Correctly evaluates NaN cases to False.
- tosa.cast - Saturating rounding when input is out of range of the output type.
- tosa.mod_shape - Rejects cases where lhs is negative or rhs is non-positive.
From b6c4b9ca5f3e619a5d09882b93cae900f8aa746f Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 26 May 2026 20:54:44 +0100
Subject: [PATCH] [mlir][tosa] Improve folder conformance to TOSA specification
This commit fixes some bugs in TOSA folders that cause
non-conformant results. The fixes include:
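- tosa.intdiv - Previously folded to zero whenever the lhs was a zero
splat, even if the rhs was zero or unknown. In the TOSA specification
0 / 0 is undefined behaviour, so the fold now requires the rhs to be
a constant non-zero splat.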
- tosa.div_ceil_shape/tosa.div_floor_shape - Folding when the lhs
is negative or the rhs is non-positive. In the TOSA specification
this is undefined behaviour.
In addition, some test cases have been added for non-exercised code
paths, including:
- tosa.intdiv - Rejects overflow cases
- tosa.greater/tosa.greater_equal/tosa.equal - Correctly evaluates NaN
cases to False.
- tosa.cast - Saturating rounding when input is out of range of the
output type.
- tosa.mod_shape - Rejects cases where lhs is negative or rhs is
non-positive.
Change-Id: I3ffaaeab700973eb167c91a23a39be6a90c842e4
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 28 ++-
mlir/test/Dialect/Tosa/constant_folding.mlir | 180 +++++++++++++++++-
2 files changed, 197 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4af185a6e534b..bdb4370146458 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1462,6 +1462,24 @@ struct Exp2FoldAdaptor {
}
};
+// The specification requires shape div operations to have non-negative lhs and
+// strictly positive rhs so we can only fold when these conditions are met.
+template <bool Ceil>
+struct ShapeDivFoldAdaptor {
+ static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+ bool isUnsigned) {
+ assert(!isUnsigned &&
+ "unsigned values are not supported for shape div folders");
+ if (lhs.isNegative() || !rhs.isStrictlyPositive())
+ return failure();
+ return DivFoldAdaptor<Ceil>::fold(lhs, rhs, isUnsigned);
+ }
+
+ static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+ return failure();
+ }
+};
+
struct Log2CeilFoldAdaptor {
static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
if (!value.isStrictlyPositive())
@@ -1595,10 +1613,12 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
- if (lhsAttr && lhsAttr.isSplat()) {
+ if (lhsAttr && lhsAttr.isSplat() && rhsAttr && rhsAttr.isSplat()) {
if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
- lhsAttr.getSplatValue<APInt>().isZero())
+ lhsAttr.getSplatValue<APInt>().isZero() &&
+ !rhsAttr.getSplatValue<APInt>().isZero()) {
return lhsAttr.resizeSplat(resultTy);
+ }
}
if (rhsAttr && rhsAttr.isSplat()) {
@@ -2411,11 +2431,11 @@ OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
- return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
+ return binaryFold<DivCeilShapeOp, ShapeDivFoldAdaptor</*Ceil*/ true>>(this);
}
OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
- return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
+ return binaryFold<DivFloorShapeOp, ShapeDivFoldAdaptor</*Ceil*/ false>>(this);
}
OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 30ba340afaa2d..118746704d4c4 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -226,11 +226,22 @@ func.func @fold_dynamic_add_broadcast_zero_rhs(%arg0: tensor<?x17xi32>) -> tenso
// -----
-// CHECK-LABEL: @fold_div_zero_lhs_i32
-func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+// CHECK-LABEL: @no_fold_div_zero_lhs_i32
+func.func @no_fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
- // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
+ // CHECK: tosa.intdiv
%div = tosa.intdiv %zero, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_zero_lhs_nonzero_splat_rhs_i32
+func.func @fold_div_zero_lhs_nonzero_splat_rhs_i32() -> tensor<i32> {
+ %lhs = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %rhs = "tosa.const"() {values = dense<2> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
+ %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: return %[[ZERO]]
return %div : tensor<i32>
}
@@ -247,12 +258,11 @@ func.func @no_fold_dynamic_div_zero_lhs(%arg0: tensor<?x4xi32>) -> tensor<?x4xi3
// -----
-// CHECK-LABEL: @fold_div_zero_lhs_broadcast
-func.func @fold_div_zero_lhs_broadcast(%arg0: tensor<2x4xi32>) -> tensor<2x4xi32> {
+// CHECK-LABEL: @no_fold_div_zero_lhs_broadcast
+func.func @no_fold_div_zero_lhs_broadcast(%arg0: tensor<2x4xi32>) -> tensor<2x4xi32> {
%zero = "tosa.const"() {values = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
- // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> : tensor<2x4xi32>
+ // CHECK: tosa.intdiv
%div = tosa.intdiv %zero, %arg0 : (tensor<1x1xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
- // CHECK: return %[[ZERO]]
return %div : tensor<2x4xi32>
}
@@ -310,6 +320,38 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
// -----
+// CHECK-LABEL: @no_fold_div_splat_i32_overflow
+func.func @no_fold_div_splat_i32_overflow() -> tensor<i32> {
+ %lhs = "tosa.const"() {values = dense<-2147483648> : tensor<i32>} : () -> tensor<i32>
+ %rhs = "tosa.const"() {values = dense<-1> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: tosa.intdiv
+ %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_div_zero_lhs_zero_rhs_i32
+func.func @no_fold_div_zero_lhs_zero_rhs_i32() -> tensor<i32> {
+ %lhs = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %rhs = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // CHECK: tosa.intdiv
+ %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_div_zero_lhs_non_splat_rhs_i32
+func.func @no_fold_div_zero_lhs_non_splat_rhs_i32() -> tensor<2xi32> {
+ %lhs = "tosa.const"() {values = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+ %rhs = "tosa.const"() {values = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+ // CHECK: tosa.intdiv
+ %div = tosa.intdiv %lhs, %rhs : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+ return %div : tensor<2xi32>
+}
+
+// -----
// CHECK-LABEL: @fold_mul_zero_rhs_f32
func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
@@ -796,6 +838,20 @@ func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
// -----
+// CHECK-LABEL: @fold_compare_nan_f32
+func.func @fold_compare_nan_f32() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) {
+ %nan = "tosa.const"() {values = dense<0x7FC00000> : tensor<10xf32>} : () -> tensor<10xf32>
+ %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+ %gt = tosa.greater %nan, %one : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+ %ge = tosa.greater_equal %one, %nan : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+ %eq = tosa.equal %nan, %nan : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+ // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+ // CHECK: return %[[FALSE]], %[[FALSE]], %[[FALSE]]
+ return %gt, %ge, %eq : tensor<10xi1>, tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
// CHECK-LABEL: @fold_eq_splat_i32
func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
@@ -924,6 +980,50 @@ func.func @cast_float_to_int_round() -> tensor<i16> {
// -----
+// CHECK: func.func @cast_float_to_int_saturates_high
+func.func @cast_float_to_int_saturates_high() -> tensor<i8> {
+ %splat = "tosa.const"() {values = dense<1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<127> : tensor<i8>}
+ %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i8>
+ // CHECK: return %[[SPLAT]]
+ return %cast : tensor<i8>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int_saturates_low
+func.func @cast_float_to_int_saturates_low() -> tensor<i8> {
+ %splat = "tosa.const"() {values = dense<-1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-128> : tensor<i8>}
+ %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i8>
+ // CHECK: return %[[SPLAT]]
+ return %cast : tensor<i8>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_unsigned_int_saturates_low
+func.func @cast_float_to_unsigned_int_saturates_low() -> tensor<ui8> {
+ %splat = "tosa.const"() {values = dense<-1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<0> : tensor<ui8>}
+ %cast = tosa.cast %splat : (tensor<f32>) -> tensor<ui8>
+ // CHECK: return %[[SPLAT]]
+ return %cast : tensor<ui8>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_unsigned_int_saturates_high
+func.func @cast_float_to_unsigned_int_saturates_high() -> tensor<ui8> {
+ %splat = "tosa.const"() {values = dense<1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+ // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<255> : tensor<ui8>}
+ %cast = tosa.cast %splat : (tensor<f32>) -> tensor<ui8>
+ // CHECK: return %[[SPLAT]]
+ return %cast : tensor<ui8>
+}
+
+// -----
+
// CHECK: func.func @cast_int_to_int_trunc
func.func @cast_int_to_int_trunc() -> tensor<i16> {
%splat = "tosa.const"() {values = dense<-1> : tensor<i32>} : () -> tensor<i32>
@@ -1234,6 +1334,28 @@ func.func @test_fold_div_ceil_shape() -> !tosa.shape<6> {
// -----
+// CHECK-LABEL: @test_no_fold_div_ceil_shape_negative_input
+// CHECK: tosa.div_ceil_shape
+func.func @test_no_fold_div_ceil_shape_negative_input() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, -7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, 2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_div_ceil_shape_negative_divisor
+// CHECK: tosa.div_ceil_shape
+func.func @test_no_fold_div_ceil_shape_negative_divisor() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, 7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, -2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
// CHECK-LABEL: @test_no_fold_div_ceil_shape_positive_overflow
// CHECK: tosa.div_ceil_shape
func.func @test_no_fold_div_ceil_shape_positive_overflow() -> !tosa.shape<6> {
@@ -1267,6 +1389,28 @@ func.func @test_fold_div_floor_shape() -> !tosa.shape<6> {
// -----
+// CHECK-LABEL: @test_no_fold_div_floor_shape_negative_input
+// CHECK: tosa.div_floor_shape
+func.func @test_no_fold_div_floor_shape_negative_input() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, -7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, 2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.div_floor_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_div_floor_shape_negative_divisor
+// CHECK: tosa.div_floor_shape
+func.func @test_no_fold_div_floor_shape_negative_divisor() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, 7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, -2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.div_floor_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
// CHECK-LABEL: @test_no_fold_div_floor_shape_positive_overflow
// CHECK: tosa.div_floor_shape
func.func @test_no_fold_div_floor_shape_positive_overflow() -> !tosa.shape<6> {
@@ -1300,6 +1444,28 @@ func.func @test_fold_mod_shape() -> !tosa.shape<6> {
// -----
+// CHECK-LABEL: @test_no_fold_mod_shape_negative_input
+// CHECK: tosa.mod_shape
+func.func @test_no_fold_mod_shape_negative_input() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, -7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_mod_shape_negative_divisor
+// CHECK: tosa.mod_shape
+func.func @test_no_fold_mod_shape_negative_divisor() -> !tosa.shape<6> {
+ %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, -2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+ %c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+ return %c : !tosa.shape<6>
+}
+
+// -----
+
// CHECK-LABEL: @test_no_fold_mod_shape_positive_overflow
// CHECK: tosa.mod_shape
func.func @test_no_fold_mod_shape_positive_overflow() -> !tosa.shape<6> {
More information about the Mlir-commits
mailing list