[Mlir-commits] [mlir] [mlir][tosa] Improve folder conformance to TOSA specification (PR #200223)

Luke Hutton llvmlistbot at llvm.org
Thu May 28 09:50:21 PDT 2026


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/200223

This commit fixes some bugs in TOSA folders that cause non-conformant results. The fixes include:
- tosa.intdiv - Folding a zero lhs to zero even when the rhs could be zero (for example, a zero or non-constant rhs). In the TOSA specification division by zero is undefined behaviour.
- tosa.div_ceil_shape/tosa.div_floor_shape - Folding when the lhs is negative or the rhs is non-positive. In the TOSA specification this is undefined behaviour.

In addition, some test cases have been added for non-exercised code paths, including:
- tosa.intdiv - Rejects overflow cases
- tosa.greater/tosa.greater_equal/tosa.equal - Correctly evaluates NaN cases to False.
- tosa.cast - Saturating rounding when input is out of range of the output type.
- tosa.mod_shape - Rejects cases where lhs is negative or rhs is non-positive.

>From b6c4b9ca5f3e619a5d09882b93cae900f8aa746f Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Tue, 26 May 2026 20:54:44 +0100
Subject: [PATCH] [mlir][tosa] Improve folder conformance to TOSA specification

This commit fixes some bugs in TOSA folders that cause
non-conformant results. The fixes include:
- tosa.intdiv - Folding a zero lhs to zero even when the rhs could
  be zero (for example, a zero or non-constant rhs). In the TOSA
  specification division by zero is undefined behaviour.
- tosa.div_ceil_shape/tosa.div_floor_shape - Folding when the lhs
  is negative or the rhs is non-positive. In the TOSA specification
  this is undefined behaviour.

In addition, some test cases have been added for non-exercised code
paths, including:
- tosa.intdiv - Rejects overflow cases
- tosa.greater/tosa.greater_equal/tosa.equal - Correctly evaluates NaN
  cases to False.
- tosa.cast - Saturating rounding when input is out of range of the
  output type.
- tosa.mod_shape - Rejects cases where lhs is negative or rhs is
  non-positive.

Change-Id: I3ffaaeab700973eb167c91a23a39be6a90c842e4
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp |  28 ++-
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 180 +++++++++++++++++-
 2 files changed, 197 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 4af185a6e534b..bdb4370146458 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1462,6 +1462,24 @@ struct Exp2FoldAdaptor {
   }
 };
 
+// The TOSA specification leaves shape division undefined unless lhs >= 0 and
+// rhs > 0, so fold only when both hold and otherwise leave the op unfolded.
+template <bool Ceil>
+struct ShapeDivFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               bool isUnsigned) {
+    assert(!isUnsigned &&
+           "unsigned values are not supported for shape div folders");
+    if (lhs.isNegative() || !rhs.isStrictlyPositive())
+      return failure();
+    return DivFoldAdaptor<Ceil>::fold(lhs, rhs, isUnsigned);
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return failure();
+  }
+};
+
 struct Log2CeilFoldAdaptor {
   static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
     if (!value.isStrictlyPositive())
@@ -1595,10 +1613,12 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
-  if (lhsAttr && lhsAttr.isSplat()) {
+  if (lhsAttr && lhsAttr.isSplat() && rhsAttr && rhsAttr.isSplat()) {
     if (llvm::isa<IntegerType>(resultETy) && resultTy.hasStaticShape() &&
-        lhsAttr.getSplatValue<APInt>().isZero())
+        lhsAttr.getSplatValue<APInt>().isZero() &&
+        !rhsAttr.getSplatValue<APInt>().isZero()) {
       return lhsAttr.resizeSplat(resultTy);
+    }
   }
 
   if (rhsAttr && rhsAttr.isSplat()) {
@@ -2411,11 +2431,11 @@ OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
-  return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
+  return binaryFold<DivCeilShapeOp, ShapeDivFoldAdaptor</*Ceil*/ true>>(this);
 }
 
 OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
-  return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
+  return binaryFold<DivFloorShapeOp, ShapeDivFoldAdaptor</*Ceil*/ false>>(this);
 }
 
 OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 30ba340afaa2d..118746704d4c4 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -226,11 +226,22 @@ func.func @fold_dynamic_add_broadcast_zero_rhs(%arg0: tensor<?x17xi32>) -> tenso
 
 // -----
 
-// CHECK-LABEL: @fold_div_zero_lhs_i32
-func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
+// CHECK-LABEL: @no_fold_div_zero_lhs_i32
+func.func @no_fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
+  // CHECK: tosa.intdiv
   %div = tosa.intdiv %zero, %arg0 : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_div_zero_lhs_nonzero_splat_rhs_i32
+func.func @fold_div_zero_lhs_nonzero_splat_rhs_i32() -> tensor<i32> {
+  %lhs = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %rhs = "tosa.const"() {values = dense<2> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0>
+  %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %[[ZERO]]
   return %div : tensor<i32>
 }
@@ -247,12 +258,11 @@ func.func @no_fold_dynamic_div_zero_lhs(%arg0: tensor<?x4xi32>) -> tensor<?x4xi3
 
 // -----
 
-// CHECK-LABEL: @fold_div_zero_lhs_broadcast
-func.func @fold_div_zero_lhs_broadcast(%arg0: tensor<2x4xi32>) -> tensor<2x4xi32> {
+// CHECK-LABEL: @no_fold_div_zero_lhs_broadcast
+func.func @no_fold_div_zero_lhs_broadcast(%arg0: tensor<2x4xi32>) -> tensor<2x4xi32> {
   %zero = "tosa.const"() {values = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
-  // CHECK: %[[ZERO:.+]] = "tosa.const"() <{values = dense<0> : tensor<2x4xi32>
+  // CHECK: tosa.intdiv
   %div = tosa.intdiv %zero, %arg0 : (tensor<1x1xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
-  // CHECK: return %[[ZERO]]
   return %div : tensor<2x4xi32>
 }
 
@@ -310,6 +320,38 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
 
 // -----
 
+// CHECK-LABEL: @no_fold_div_splat_i32_overflow
+func.func @no_fold_div_splat_i32_overflow() -> tensor<i32> {
+  %lhs = "tosa.const"() {values = dense<-2147483648> : tensor<i32>} : () -> tensor<i32>
+  %rhs = "tosa.const"() {values = dense<-1> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: tosa.intdiv
+  %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_div_zero_lhs_zero_rhs_i32
+func.func @no_fold_div_zero_lhs_zero_rhs_i32() -> tensor<i32> {
+  %lhs = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %rhs = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // CHECK: tosa.intdiv
+  %div = tosa.intdiv %lhs, %rhs : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  return %div : tensor<i32>
+}
+
+// -----
+
+// CHECK-LABEL: @no_fold_div_zero_lhs_non_splat_rhs_i32
+func.func @no_fold_div_zero_lhs_non_splat_rhs_i32() -> tensor<2xi32> {
+  %lhs = "tosa.const"() {values = dense<0> : tensor<2xi32>} : () -> tensor<2xi32>
+  %rhs = "tosa.const"() {values = dense<[1, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // CHECK: tosa.intdiv
+  %div = tosa.intdiv %lhs, %rhs : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
+  return %div : tensor<2xi32>
+}
+
+// -----
 
 // CHECK-LABEL: @fold_mul_zero_rhs_f32
 func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
@@ -796,6 +838,20 @@ func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
 
 // -----
 
+// CHECK-LABEL: @fold_compare_nan_f32
+func.func @fold_compare_nan_f32() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) {
+  %nan = "tosa.const"() {values = dense<0x7FC00000> : tensor<10xf32>} : () -> tensor<10xf32>
+  %one = "tosa.const"() {values = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
+  %gt = tosa.greater %nan, %one : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %ge = tosa.greater_equal %one, %nan : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  %eq = tosa.equal %nan, %nan : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[FALSE]], %[[FALSE]], %[[FALSE]]
+  return %gt, %ge, %eq : tensor<10xi1>, tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_eq_splat_i32
 func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {values = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
@@ -924,6 +980,50 @@ func.func @cast_float_to_int_round() -> tensor<i16> {
 
 // -----
 
+// CHECK: func.func @cast_float_to_int_saturates_high
+func.func @cast_float_to_int_saturates_high() -> tensor<i8> {
+  %splat = "tosa.const"() {values = dense<1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<127> : tensor<i8>}
+  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i8>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i8>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_int_saturates_low
+func.func @cast_float_to_int_saturates_low() -> tensor<i8> {
+  %splat = "tosa.const"() {values = dense<-1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<-128> : tensor<i8>}
+  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<i8>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<i8>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_unsigned_int_saturates_low
+func.func @cast_float_to_unsigned_int_saturates_low() -> tensor<ui8> {
+  %splat = "tosa.const"() {values = dense<-1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<0> : tensor<ui8>}
+  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<ui8>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<ui8>
+}
+
+// -----
+
+// CHECK: func.func @cast_float_to_unsigned_int_saturates_high
+func.func @cast_float_to_unsigned_int_saturates_high() -> tensor<ui8> {
+  %splat = "tosa.const"() {values = dense<1.000000e+20> : tensor<f32>} : () -> tensor<f32>
+  // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{values = dense<255> : tensor<ui8>}
+  %cast = tosa.cast %splat : (tensor<f32>) -> tensor<ui8>
+  // CHECK: return %[[SPLAT]]
+  return %cast : tensor<ui8>
+}
+
+// -----
+
 // CHECK: func.func @cast_int_to_int_trunc
 func.func @cast_int_to_int_trunc() -> tensor<i16> {
   %splat = "tosa.const"() {values = dense<-1> : tensor<i32>} : () -> tensor<i32>
@@ -1234,6 +1334,28 @@ func.func @test_fold_div_ceil_shape() -> !tosa.shape<6> {
 
 // -----
 
+// CHECK-LABEL: @test_no_fold_div_ceil_shape_negative_input
+// CHECK: tosa.div_ceil_shape
+func.func @test_no_fold_div_ceil_shape_negative_input() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, -7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, 2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_div_ceil_shape_negative_divisor
+// CHECK: tosa.div_ceil_shape
+func.func @test_no_fold_div_ceil_shape_negative_divisor() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, 7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, -2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.div_ceil_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
 // CHECK-LABEL: @test_no_fold_div_ceil_shape_positive_overflow
 // CHECK: tosa.div_ceil_shape
 func.func @test_no_fold_div_ceil_shape_positive_overflow() -> !tosa.shape<6> {
@@ -1267,6 +1389,28 @@ func.func @test_fold_div_floor_shape() -> !tosa.shape<6> {
 
 // -----
 
+// CHECK-LABEL: @test_no_fold_div_floor_shape_negative_input
+// CHECK: tosa.div_floor_shape
+func.func @test_no_fold_div_floor_shape_negative_input() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, -7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, 2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_div_floor_shape_negative_divisor
+// CHECK: tosa.div_floor_shape
+func.func @test_no_fold_div_floor_shape_negative_divisor() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[2, 7, 11, 22, 47, 7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[1, 2, 2, 3, 5, -2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.div_floor_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
 // CHECK-LABEL: @test_no_fold_div_floor_shape_positive_overflow
 // CHECK: tosa.div_floor_shape
 func.func @test_no_fold_div_floor_shape_positive_overflow() -> !tosa.shape<6> {
@@ -1300,6 +1444,28 @@ func.func @test_fold_mod_shape() -> !tosa.shape<6> {
 
 // -----
 
+// CHECK-LABEL: @test_no_fold_mod_shape_negative_input
+// CHECK: tosa.mod_shape
+func.func @test_no_fold_mod_shape_negative_input() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, -7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, 2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
+// CHECK-LABEL: @test_no_fold_mod_shape_negative_divisor
+// CHECK: tosa.mod_shape
+func.func @test_no_fold_mod_shape_negative_divisor() -> !tosa.shape<6> {
+  %a = tosa.const_shape {values = dense<[24, 7, 65, 33, 39, 7]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %b = tosa.const_shape {values = dense<[11, 2, 12, 13, 15, -2]> : tensor<6xindex>} : () -> !tosa.shape<6>
+  %c = tosa.mod_shape %a, %b : (!tosa.shape<6>, !tosa.shape<6>) -> !tosa.shape<6>
+  return %c : !tosa.shape<6>
+}
+
+// -----
+
 // CHECK-LABEL: @test_no_fold_mod_shape_positive_overflow
 // CHECK: tosa.mod_shape
 func.func @test_no_fold_mod_shape_positive_overflow() -> !tosa.shape<6> {



More information about the Mlir-commits mailing list