[Mlir-commits] [mlir] [mlir][tosa] Check for overflow in binary integer folders (PR #172695)

Ian Tayler Lessa llvmlistbot at llvm.org
Mon Jan 19 03:24:32 PST 2026


================
@@ -889,33 +889,101 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // Operator Folders.
 //===----------------------------------------------------------------------===//
 
-template <typename IntFolder, typename FloatFolder>
+template <typename Folder>
 static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
                                       DenseElementsAttr rhs,
                                       RankedTensorType returnTy) {
   if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
-    auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-    auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+    const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+    const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
     if (lETy != rETy)
       return {};
 
-    if (llvm::isa<IntegerType>(lETy)) {
-      APInt l = lhs.getSplatValue<APInt>();
-      APInt r = rhs.getSplatValue<APInt>();
-      auto result = IntFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+    if (const auto lIntTy = dyn_cast<IntegerType>(lETy)) {
+      const APInt l = lhs.getSplatValue<APInt>();
+      const APInt r = rhs.getSplatValue<APInt>();
+      const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
 
     if (llvm::isa<FloatType>(lETy)) {
-      APFloat l = lhs.getSplatValue<APFloat>();
-      APFloat r = rhs.getSplatValue<APFloat>();
-      auto result = FloatFolder()(l, r);
-      return DenseElementsAttr::get(returnTy, result);
+      const APFloat l = lhs.getSplatValue<APFloat>();
+      const APFloat r = rhs.getSplatValue<APFloat>();
+      const auto maybeResult = Folder::fold(l, r);
+      if (failed(maybeResult))
+        return {};
+      return DenseElementsAttr::get(returnTy, maybeResult.value());
     }
   }
 
   return {};
 }
+struct AddFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    bool overflow;
+    const APInt result =
+        isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow);
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs + rhs;
+  }
+};
+
+struct SubFoldAdaptor {
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    bool overflow;
+    const APInt result =
+        isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow);
+    if (overflow)
+      return failure();
+    return result;
+  }
+
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs - rhs;
+  }
+};
+
+struct FoldGreaterAdaptor {
----------------
IanTaylerLessa-arm wrote:

Nit: we're using `AddFoldAdaptor` but `FoldGreaterAdaptor`. It would be nice to name them consistently, either all as `<OP>FoldAdaptor` or all as `Fold<OP>Adaptor`.

https://github.com/llvm/llvm-project/pull/172695


More information about the Mlir-commits mailing list