[Mlir-commits] [mlir] 60621cd - [mlir][tosa] Check for overflow in binary integer folders (#172695)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 20 02:58:44 PST 2026


Author: Luke Hutton
Date: 2026-01-20T10:58:40Z
New Revision: 60621cd2739c5b1ea3db4630fdd51cd8b3eaf57e

URL: https://github.com/llvm/llvm-project/commit/60621cd2739c5b1ea3db4630fdd51cd8b3eaf57e
DIFF: https://github.com/llvm/llvm-project/commit/60621cd2739c5b1ea3db4630fdd51cd8b3eaf57e.diff

LOG: [mlir][tosa] Check for overflow in binary integer folders (#172695)

For the binary integer folders (add and sub) to be TOSA compliant, they
need to check for overflow. This commit adds those checks, so that folding
is skipped when an overflow is detected.

This commit also fixes the greater/greater_equal folders to account for
unsigned types.

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/test/Dialect/Tosa/constant_folding.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..b15a3a4279064 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -889,33 +889,101 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
-template <typename IntFolder, typename FloatFolder>
+template <typename Folder>
 static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
                                       DenseElementsAttr rhs,
                                       RankedTensorType returnTy) {
   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+    const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+    const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
     if (lETy != rETy)
       return {};
 
-    if (llvm::isa<IntegerType>(lETy)) {
-      APInt l = lhs.getSplatValue<APInt>();
-      APInt r = rhs.getSplatValue<APInt>();
-      auto result = IntFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+    if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) {
+      const APInt l = lhs.getSplatValue<APInt>();
+      const APInt r = rhs.getSplatValue<APInt>();
+      const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
 
     if (llvm::isa<FloatType>(lETy)) {
-      APFloat l = lhs.getSplatValue<APFloat>();
-      APFloat r = rhs.getSplatValue<APFloat>();
-      auto result = FloatFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+      const APFloat l = lhs.getSplatValue<APFloat>();
+      const APFloat r = rhs.getSplatValue<APFloat>();
+      const auto maybeResult = Folder::fold(l, r);
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
   }
 
   return {};
 }
+struct FoldAddAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    bool overflow;
+    const APInt result =
+        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs + rhs;
+  }
+};
+
+struct FoldSubAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    bool overflow;
+    const APInt result =
+        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs - rhs;
+  }
+};
+
+struct FoldGreaterAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs));
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs > rhs);
+  }
+};
+
+struct FoldGreaterEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs));
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs >= rhs);
+  }
+};
+
+struct FoldEqualAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    return APInt(1, lhs == rhs);
+  }
+
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    return APInt(1, lhs == rhs);
+  }
+};
 
 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
   if (llvm::isa<FloatType>(elemType))
@@ -963,8 +1031,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
-                                                            resultTy);
+  return binaryFolder<FoldAddAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -1145,38 +1212,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return binaryFolder<FoldSubAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
-namespace {
-template <typename Cmp>
-struct ComparisonFold {
-  ComparisonFold() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, Cmp()(l, r));
-  }
-
-  APInt operator()(const APFloat &l, const APFloat &r) {
-    return APInt(1, Cmp()(l, r));
-  }
-};
-
-struct APIntFoldGreater {
-  APIntFoldGreater() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, l.sgt(r));
-  }
-};
-
-struct APIntFoldGreaterEqual {
-  APIntFoldGreaterEqual() = default;
-  APInt operator()(const APInt &l, const APInt &r) {
-    return APInt(1, l.sge(r));
-  }
-};
-} // namespace
-
 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
   auto lhsAttr =
@@ -1187,8 +1225,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
-      lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1201,9 +1238,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<APIntFoldGreaterEqual,
-                      ComparisonFold<std::greater_equal<APFloat>>>(
-      lhsAttr, rhsAttr, resultTy);
+  return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1226,9 +1261,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
-  return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
-                      ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
-                                                              resultTy);
+  return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
 }
 
 OpFoldResult CastOp::fold(FoldAdaptor adaptor) {

diff  --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 68bfd4c7e4980..8c375b6c528ef 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -92,6 +92,50 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_add_splat_i32_positive_overflow
+func.func @fold_add_splat_i32_positive_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.add
+  %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_i32_negative_overflow
+func.func @fold_add_splat_i32_negative_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.add
+  %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8
+func.func @fold_add_splat_ui8() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: "tosa.const"() <{values = dense<255> : tensor<10xui8>}> : () -> tensor<10xui8>
+  %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %add : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8_overflow
+func.func @fold_add_splat_ui8_overflow() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<2> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: tosa.add
+  %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %add : tensor<10xui8>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_div_zero_lhs_i32
 func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -288,6 +332,50 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_sub_splat_i32_positive_overflow
+func.func @fold_sub_splat_i32_positive_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_i32_negative_overflow
+func.func @fold_sub_splat_i32_negative_overflow() -> tensor<10xi32> {
+  %one = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32>
+  %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32>
+  return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8
+func.func @fold_sub_splat_ui8() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: "tosa.const"() <{values = dense<2> : tensor<10xui8>}> : () -> tensor<10xui8>
+  %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %sub : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8_overflow
+func.func @fold_sub_splat_ui8_overflow() -> tensor<10xui8> {
+  %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+  // CHECK: tosa.sub
+  %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8>
+  return %sub : tensor<10xui8>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_greater_splat_f32
 func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -320,6 +408,23 @@ func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
 
 // -----
 
+// CHECK-LABEL: @fold_greater_splat_ui8
+func.func @fold_greater_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %1 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %2 = "tosa.const"() {values = dense<246> : tensor<10xui8>} : () -> tensor<10xui8>
+  %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+  %true = tosa.greater %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  %false = tosa.greater %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  %false2 = tosa.greater %0, %2 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]], %[[FALSE]]
+  return %true, %false, %false2 : tensor<10xi1>, tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_greater_eq_splat_f32
 func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -352,6 +457,22 @@ func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
 
 // -----
 
+// CHECK-LABEL: @fold_greater_eq_splat_ui8
+func.func @fold_greater_eq_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>) {
+  %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+  %1 = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8>
+  %2 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+  %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+  %true = tosa.greater_equal %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  %false = tosa.greater_equal %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1>
+  // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+  // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+  // CHECK: return %[[TRUE]], %[[FALSE]]
+  return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_eq_splat_f32
 func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>


        


More information about the Mlir-commits mailing list