[Mlir-commits] [mlir] [mlir][tosa] Check for overflow in binary integer folders (PR #172695)
Luke Hutton
llvmlistbot at llvm.org
Fri Jan 16 09:00:13 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/172695
>From 6a1c24756f33536485ff67511d3e9dfa969e4b98 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 17 Dec 2025 13:28:15 +0000
Subject: [PATCH] [mlir][tosa] Check for overflow in integer folders
For the integer folders to be TOSA compliant, they need to check
for overflow. This commit adds those checks, thereby
preventing folding if an overflow is detected.
This commit also fixes the greater/greater_equal folders
to account for unsigned types.
Change-Id: I2b5a5b92fb840d6c34a1f2faa18ae68a20d0ecdf
---
.../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 137 +++++++++++-------
mlir/test/Dialect/Tosa/constant_folding.mlir | 121 ++++++++++++++++
2 files changed, 206 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c420a4c9596ff..654f65cca0fe2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -889,33 +889,101 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
// Operator Folders.
//===----------------------------------------------------------------------===//
-template <typename IntFolder, typename FloatFolder>
+template <typename Folder> // Folder supplies static fold() overloads: (APInt, APInt, bool isUnsigned) and (APFloat, APFloat)
static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
DenseElementsAttr rhs,
RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
- auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+ const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
if (lETy != rETy)
return {};
- if (llvm::isa<IntegerType>(lETy)) {
- APInt l = lhs.getSplatValue<APInt>();
- APInt r = rhs.getSplatValue<APInt>();
- auto result = IntFolder()(l, r);
- return DenseElementsAttr::get(returnTy, result);
+ if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) { // integer elements: keep the IntegerType so its signedness can be queried
+ const APInt l = lhs.getSplatValue<APInt>();
+ const APInt r = rhs.getSplatValue<APInt>();
+ const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned()); // the folder picks its unsigned or signed variant from this flag
+ if (failed(maybeResult)) // folder declined (AddFoldAdaptor/SubFoldAdaptor do so on overflow)
+ return {}; // same empty result as the other no-fold exits of this function
+ return DenseElementsAttr::get(returnTy, maybeResult.value());
}
if (llvm::isa<FloatType>(lETy)) {
- APFloat l = lhs.getSplatValue<APFloat>();
- APFloat r = rhs.getSplatValue<APFloat>();
- auto result = FloatFolder()(l, r);
- return DenseElementsAttr::get(returnTy, result);
+ const APFloat l = lhs.getSplatValue<APFloat>();
+ const APFloat r = rhs.getSplatValue<APFloat>();
+ const auto maybeResult = Folder::fold(l, r); // float overload takes only the two operands
+ if (failed(maybeResult)) // kept symmetric with the integer path
+ return {};
+ return DenseElementsAttr::get(returnTy, maybeResult.value());
}
}
return {};
}
+struct AddFoldAdaptor { // folder handed to binaryFolder by AddOp::fold
+ static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+ const bool isUnsigned) {
+ bool overflow; // output flag of the *_ov call below; only read after that call
+ const APInt result = isUnsigned ? lhs.uadd_ov(rhs, overflow) // unsigned overflow-checked add
+ : lhs.sadd_ov(rhs, overflow); // signed overflow-checked add
+ if (overflow) // an overflowing sum is reported as failure rather than as a value
+ return failure();
+ return result;
+ }
+
+ static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+ return lhs + rhs; // floating-point overload has no failure return
+ }
+};
+
+struct SubFoldAdaptor { // folder handed to binaryFolder by SubOp::fold
+ static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+ const bool isUnsigned) {
+ bool overflow; // output flag of the *_ov call below; only read after that call
+ const APInt result = isUnsigned ? lhs.usub_ov(rhs, overflow) // unsigned overflow-checked subtract
+ : lhs.ssub_ov(rhs, overflow); // signed overflow-checked subtract
+ if (overflow) // an overflowing difference is reported as failure rather than as a value
+ return failure();
+ return result;
+ }
+
+ static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+ return lhs - rhs; // floating-point overload has no failure return
+ }
+};
+
+struct FoldGreaterAdaptor { // folder handed to binaryFolder by GreaterOp::fold
+ static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+ const bool isUnsigned) {
+ return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs)); // 1-bit result; unsigned or signed greater-than per isUnsigned
+ }
+
+ static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+ return APInt(1, lhs > rhs); // float comparison, also returned as a 1-bit APInt
+ }
+};
+
+struct FoldGreaterEqualAdaptor { // folder handed to binaryFolder by GreaterEqualOp::fold
+ static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+ const bool isUnsigned) {
+ return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs)); // 1-bit result; unsigned or signed greater-or-equal per isUnsigned
+ }
+
+ static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+ return APInt(1, lhs >= rhs); // float comparison, also returned as a 1-bit APInt
+ }
+};
+
+struct FoldEqualAdaptor { // folder handed to binaryFolder by EqualOp::fold
+ static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+ const bool isUnsigned) { // isUnsigned is part of the shared folder interface; it is not consulted here
+ return APInt(1, lhs == rhs); // 1-bit result of the equality test
+ }
+
+ static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+ return APInt(1, lhs == rhs); // float equality, also returned as a 1-bit APInt
+ }
+};
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
@@ -963,8 +1031,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -1145,38 +1212,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
-namespace {
-template <typename Cmp>
-struct ComparisonFold {
- ComparisonFold() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, Cmp()(l, r));
- }
-
- APInt operator()(const APFloat &l, const APFloat &r) {
- return APInt(1, Cmp()(l, r));
- }
-};
-
-struct APIntFoldGreater {
- APIntFoldGreater() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, l.sgt(r));
- }
-};
-
-struct APIntFoldGreaterEqual {
- APIntFoldGreaterEqual() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, l.sge(r));
- }
-};
-} // namespace
-
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr =
@@ -1187,8 +1225,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
- lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<FoldGreaterAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1201,9 +1238,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreaterEqual,
- ComparisonFold<std::greater_equal<APFloat>>>(
- lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<FoldGreaterEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1226,9 +1261,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<FoldEqualAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index 68bfd4c7e4980..8c375b6c528ef 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -92,6 +92,50 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
// -----
+// CHECK-LABEL: @fold_add_splat_i32_positive_overflow
+func.func @fold_add_splat_i32_positive_overflow() -> tensor<10xi32> {
+ %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32> // 2^31 - 1, the largest signed 32-bit value
+ %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+ // CHECK: tosa.add
+ %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // sum exceeds the signed 32-bit range, so the add is expected to remain
+ return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_i32_negative_overflow
+func.func @fold_add_splat_i32_negative_overflow() -> tensor<10xi32> {
+ %one = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+ %two = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32> // -2^31, the smallest signed 32-bit value
+ // CHECK: tosa.add
+ %add = tosa.add %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // sum falls below the signed 32-bit range, so the add is expected to remain
+ return %add : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8
+func.func @fold_add_splat_ui8() -> tensor<10xui8> {
+ %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+ // CHECK: "tosa.const"() <{values = dense<255> : tensor<10xui8>}> : () -> tensor<10xui8>
+ %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 1 + 254 = 255 still fits in 8 unsigned bits, so a folded constant is expected
+ return %add : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_splat_ui8_overflow
+func.func @fold_add_splat_ui8_overflow() -> tensor<10xui8> {
+ %one = "tosa.const"() {values = dense<2> : tensor<10xui8>} : () -> tensor<10xui8>
+ %two = "tosa.const"() {values = dense<254> : tensor<10xui8>} : () -> tensor<10xui8>
+ // CHECK: tosa.add
+ %add = tosa.add %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 2 + 254 = 256 exceeds the 8-bit unsigned maximum of 255, so the add is expected to remain
+ return %add : tensor<10xui8>
+}
+
+// -----
+
// CHECK-LABEL: @fold_div_zero_lhs_i32
func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
%zero = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -288,6 +332,50 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
// -----
+// CHECK-LABEL: @fold_sub_splat_i32_positive_overflow
+func.func @fold_sub_splat_i32_positive_overflow() -> tensor<10xi32> {
+ %one = "tosa.const"() {values = dense<2147483647> : tensor<10xi32>} : () -> tensor<10xi32> // 2^31 - 1, the largest signed 32-bit value
+ %two = "tosa.const"() {values = dense<-1> : tensor<10xi32>} : () -> tensor<10xi32>
+ // CHECK: tosa.sub
+ %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // difference exceeds the signed 32-bit range, so the sub is expected to remain
+ return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_i32_negative_overflow
+func.func @fold_sub_splat_i32_negative_overflow() -> tensor<10xi32> {
+ %one = "tosa.const"() {values = dense<-2147483648> : tensor<10xi32>} : () -> tensor<10xi32> // -2^31, the smallest signed 32-bit value
+ %two = "tosa.const"() {values = dense<1> : tensor<10xi32>} : () -> tensor<10xi32>
+ // CHECK: tosa.sub
+ %sub = tosa.sub %one, %two : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi32> // difference falls below the signed 32-bit range, so the sub is expected to remain
+ return %sub : tensor<10xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8
+func.func @fold_sub_splat_ui8() -> tensor<10xui8> {
+ %one = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8>
+ %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+ // CHECK: "tosa.const"() <{values = dense<2> : tensor<10xui8>}> : () -> tensor<10xui8>
+ %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 255 - 253 = 2 is representable, so a folded constant is expected
+ return %sub : tensor<10xui8>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_sub_splat_ui8_overflow
+func.func @fold_sub_splat_ui8_overflow() -> tensor<10xui8> {
+ %one = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %two = "tosa.const"() {values = dense<253> : tensor<10xui8>} : () -> tensor<10xui8>
+ // CHECK: tosa.sub
+ %sub = tosa.sub %one, %two : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xui8> // 1 - 253 is negative and not representable as unsigned, so the sub is expected to remain
+ return %sub : tensor<10xui8>
+}
+
+// -----
+
// CHECK-LABEL: @fold_greater_splat_f32
func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -320,6 +408,23 @@ func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
// -----
+// CHECK-LABEL: @fold_greater_splat_ui8
+func.func @fold_greater_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>, tensor<10xi1>) {
+ %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %1 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %2 = "tosa.const"() {values = dense<246> : tensor<10xui8>} : () -> tensor<10xui8> // above 127: the same 8 bits read as signed would be -10
+ %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+ %true = tosa.greater %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 246 > 245
+ %false = tosa.greater %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 1 > 1 does not hold
+ %false2 = tosa.greater %0, %2 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 1 > 246 is false only under an unsigned comparison (signed would compare 1 > -10)
+ // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+ // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+ // CHECK: return %[[TRUE]], %[[FALSE]], %[[FALSE]]
+ return %true, %false, %false2 : tensor<10xi1>, tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
// CHECK-LABEL: @fold_greater_eq_splat_f32
func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
@@ -352,6 +457,22 @@ func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
// -----
+// CHECK-LABEL: @fold_greater_eq_splat_ui8
+func.func @fold_greater_eq_splat_ui8() -> (tensor<10xi1>, tensor<10xi1>) {
+ %0 = "tosa.const"() {values = dense<1> : tensor<10xui8>} : () -> tensor<10xui8>
+ %1 = "tosa.const"() {values = dense<255> : tensor<10xui8>} : () -> tensor<10xui8> // above 127: the same 8 bits read as signed would be -1
+ %2 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+ %3 = "tosa.const"() {values = dense<245> : tensor<10xui8>} : () -> tensor<10xui8>
+ %true = tosa.greater_equal %2, %3 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 245 >= 245
+ %false = tosa.greater_equal %0, %1 : (tensor<10xui8>, tensor<10xui8>) -> tensor<10xi1> // 1 >= 255 is false only under an unsigned comparison (signed would compare 1 >= -1)
+ // CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() <{values = dense<true> : tensor<10xi1>}
+ // CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() <{values = dense<false> : tensor<10xi1>}
+ // CHECK: return %[[TRUE]], %[[FALSE]]
+ return %true, %false : tensor<10xi1>, tensor<10xi1>
+}
+
+// -----
+
// CHECK-LABEL: @fold_eq_splat_f32
func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {values = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
More information about the Mlir-commits
mailing list