[Mlir-commits] [mlir] [mlir][arith] Fix overflow bug in arith::CeilDivSIOp::fold (PR #90947)

Andrzej Warzyński llvmlistbot at llvm.org
Fri May 3 01:10:44 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/90947

The folder for arith::CeilDivSIOp should only be applied when it can be
guaranteed that no overflow would happen. The current implementation
works fine when both operands are positive and the only arithmetic
operation is the division itself.

However, when at least one of the operands is negative, the division is
split into multiple operations, e.g. `- ( -a / b)`. That adds two more
operations on top of the actual division, and any of the three can
overflow, so the folder should check all of them. The current logic
doesn't: each `*_ov` call overwrites the shared `overflowOrDiv0` flag,
so effectively only the last operation is checked. This breaks for e.g.
MININT values (such as -128 for 8-bit integers), because negating such
values overflows.
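To make the failure concrete, here is a minimal C++ sketch, assuming 8-bit values are emulated with plain `int`s. The helpers `wrap8`, `ssubOv8` and `sdivOv8` are hypothetical stand-ins for `APInt::ssub_ov` and `APInt::sdiv_ov`; like those methods, they assign the overflow flag rather than OR-ing into it. The hypothetical `oldNegPos` replays the pre-patch "a negative, b positive" path with a single shared flag.

```cpp
#include <cassert>

// Hypothetical helpers emulating 8-bit APInt arithmetic on plain ints.
// Each one *assigns* the overflow flag, as APInt::ssub_ov/sdiv_ov do.
static int wrap8(int v) {
  // Two's-complement wrap-around into [-128, 127].
  return ((v + 128) % 256 + 256) % 256 - 128;
}

static int ssubOv8(int lhs, int rhs, bool &overflow) {
  int exact = lhs - rhs;
  overflow = wrap8(exact) != exact;
  return wrap8(exact);
}

static int sdivOv8(int lhs, int rhs, bool &overflow) {
  assert(rhs != 0 && "division by zero is rejected before this point");
  int exact = lhs / rhs;            // truncates toward zero, like sdiv
  overflow = wrap8(exact) != exact; // only -128 / -1 can overflow
  return wrap8(exact);
}

// Pre-patch shape of the "a negative, b positive" case: - (-a / b),
// threading one shared flag through all three operations.
static int oldNegPos(int a, int b, bool &overflowOrDiv0) {
  int posA = ssubOv8(0, a, overflowOrDiv0);   // for a == -128: wraps, flag = true
  int div = sdivOv8(posA, b, overflowOrDiv0); // flag overwritten (false for b > 0)
  return ssubOv8(0, div, overflowOrDiv0);     // flag overwritten again
}
```

With `a = -128` and `b = 7`, negating `a` wraps back to -128 and sets the flag, but the division and the final negation both reset it to `false`. `oldNegPos` therefore returns 18 with no overflow reported, while the exact value of ceil(-128 / 7) is -18.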

This PR makes sure that no folding happens if any of the intermediate
arithmetic operations overflows.
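The fix can be sketched with the same kind of 8-bit emulation. In the hypothetical `foldCeilDivSI8` below (not the real MLIR code), each intermediate step gets its own flag and the fold is rejected if any of them is set, mirroring the structure of the patch. The helpers `wrap8`, `ssubOv8`, `saddOv8` and `sdivOv8` are assumed stand-ins for the corresponding `APInt` methods, and the positive/positive case uses `(a - 1) / b + 1` as an assumed stand-in for `signedCeilNonnegInputs`, whose body is not part of this diff.

```cpp
#include <cassert>

// Hypothetical 8-bit stand-ins for APInt::ssub_ov/sadd_ov/sdiv_ov; each one
// assigns its overflow flag and returns the wrapped two's-complement result.
static int wrap8(int v) { return ((v + 128) % 256 + 256) % 256 - 128; }

static int ssubOv8(int lhs, int rhs, bool &ov) {
  int exact = lhs - rhs;
  ov = wrap8(exact) != exact;
  return wrap8(exact);
}

static int saddOv8(int lhs, int rhs, bool &ov) {
  int exact = lhs + rhs;
  ov = wrap8(exact) != exact;
  return wrap8(exact);
}

static int sdivOv8(int lhs, int rhs, bool &ov) {
  assert(rhs != 0 && "division by zero is rejected before this point");
  int exact = lhs / rhs; // truncates toward zero, like sdiv
  ov = wrap8(exact) != exact;
  return wrap8(exact);
}

// ceil(a / b) for a, b > 0 as (a - 1) / b + 1; an assumed stand-in for
// signedCeilNonnegInputs. Every step keeps its own flag.
static int ceilNonneg8(int a, int b, bool &ov) {
  bool ovSub = false, ovDiv = false, ovAdd = false;
  int q = sdivOv8(ssubOv8(a, 1, ovSub), b, ovDiv);
  int res = saddOv8(q, 1, ovAdd);
  ov = ovSub || ovDiv || ovAdd;
  return res;
}

// Folds ceildivsi(a, b) on i8 constants. Returns false (no fold) on division
// by zero or if ANY intermediate step overflows.
static bool foldCeilDivSI8(int a, int b, int &result) {
  if (b == 0)
    return false;
  if (a == 0) {
    result = 0;
    return true;
  }
  bool ovNegA = false, ovNegB = false, ovDiv = false, ovNegDiv = false;
  if (a > 0 && b > 0) {
    result = ceilNonneg8(a, b, ovDiv);
    return !ovDiv;
  }
  if (a < 0 && b < 0) { // ceil(-a, -b)
    int posA = ssubOv8(0, a, ovNegA);
    int posB = ssubOv8(0, b, ovNegB);
    result = ceilNonneg8(posA, posB, ovDiv);
    return !(ovNegA || ovNegB || ovDiv);
  }
  if (a < 0) { // b > 0: - (-a / b)
    int posA = ssubOv8(0, a, ovNegA);
    int div = sdivOv8(posA, b, ovDiv);
    result = ssubOv8(0, div, ovNegDiv);
    return !(ovNegA || ovDiv || ovNegDiv);
  }
  // a > 0, b < 0: - (a / -b)
  int posB = ssubOv8(0, b, ovNegB);
  int div = sdivOv8(a, posB, ovDiv);
  result = ssubOv8(0, div, ovNegDiv);
  return !(ovNegB || ovDiv || ovNegDiv);
}
```

For instance, `foldCeilDivSI8(-128, 7, r)` returns `false`, while `foldCeilDivSI8(-127, 7, r)` folds to -18. Keeping one flag per step means an early overflow can no longer be masked by a later operation that happens not to overflow.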


From c6480bcae235efb91868403aa9577d515aeb3335 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 3 May 2024 07:42:57 +0000
Subject: [PATCH] [mlir][arith] Fix overflow bug in arith::CeilDivSIOp::fold

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp  | 31 ++++++++++++++++++-------
 mlir/test/Transforms/constant-fold.mlir | 20 ++++++++++++++++
 2 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 6f995b93bc3ecd..a89634dfe07a2f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -701,22 +701,35 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
           // Both positive, return ceil(a, b).
           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
         }
+
+        bool overflowNegA = false;
+        bool overflowNegB = false;
+        bool overflowNegDiv = false;
+        bool overflowDiv = false;
         if (!aGtZero && !bGtZero) {
           // Both negative, return ceil(-a, -b).
-          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
-          APInt posB = zero.ssub_ov(b, overflowOrDiv0);
-          return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
+          APInt posA = zero.ssub_ov(a, overflowNegA);
+          APInt posB = zero.ssub_ov(b, overflowNegB);
+          APInt res = signedCeilNonnegInputs(posA, posB, overflowDiv);
+          overflowOrDiv0 =
+            (overflowNegA || overflowNegB || overflowDiv);
+          return res;
         }
         if (!aGtZero && bGtZero) {
           // A is negative, b is positive, return - ( -a / b).
-          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
-          APInt div = posA.sdiv_ov(b, overflowOrDiv0);
-          return zero.ssub_ov(div, overflowOrDiv0);
+          APInt posA = zero.ssub_ov(a, overflowNegA);
+          APInt div = posA.sdiv_ov(b, overflowDiv);
+          APInt res = zero.ssub_ov(div, overflowNegDiv);
+          overflowOrDiv0 = (overflowNegA || overflowDiv || overflowNegDiv);
+          return res;
         }
         // A is positive, b is negative, return - (a / -b).
-        APInt posB = zero.ssub_ov(b, overflowOrDiv0);
-        APInt div = a.sdiv_ov(posB, overflowOrDiv0);
-        return zero.ssub_ov(div, overflowOrDiv0);
+        APInt posB = zero.ssub_ov(b, overflowNegB);
+        APInt div = a.sdiv_ov(posB, overflowDiv);
+        APInt res = zero.ssub_ov(div, overflowNegDiv);
+
+        overflowOrDiv0 = (overflowNegB || overflowDiv || overflowNegDiv);
+        return res;
       });
 
   return overflowOrDiv0 ? Attribute() : result;
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 253163f2af9110..8507342f59297f 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -478,6 +478,26 @@ func.func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {
 
 // -----
 
+// CHECK-LABEL: func @simple_arith.ceildivsi_overflow
+func.func @simple_arith.ceildivsi_overflow() -> (i8, i16, i32) {
+  // CHECK-COUNT-3:  arith.ceildivsi
+  %0 = arith.constant 7 : i8
+  %1 = arith.constant -128 : i8
+  %2 = arith.ceildivsi %1, %0 : i8
+
+  %3 = arith.constant 7 : i16
+  %4 = arith.constant -32768 : i16
+  %5 = arith.ceildivsi %4, %3 : i16
+
+  %6 = arith.constant 7 : i32
+  %7 = arith.constant -2147483648 : i32
+  %8 = arith.ceildivsi %7, %6 : i32
+
+  return %2, %5, %8 : i8, i16, i32
+}
+
+// -----
+
 // CHECK-LABEL: func @simple_arith.ceildivui
 func.func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
   // CHECK-DAG: [[C0:%.+]] = arith.constant 0
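The new test uses the minimum value of each width as the dividend (-128 for i8, -32768 for i16, -2147483648 for i32), so the first step of `- ( -a / b)`, negating `a`, overflows in all three cases and the `CHECK-COUNT-3` line expects every `arith.ceildivsi` to survive folding. The sketch below checks that property in C++, assuming a GCC- or Clang-compatible compiler: it relies on the `__builtin_sub_overflow` builtin, which is not standard C++, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Returns true when computing 0 - v overflows type T. Uses the GCC/Clang
// builtin __builtin_sub_overflow, which is not standard C++.
template <typename T>
bool negationOverflows(T v) {
  T res;
  return __builtin_sub_overflow(T(0), v, &res);
}

// The dividends used by @simple_arith.ceildivsi_overflow are the minimum
// values of i8, i16 and i32, i.e. std::numeric_limits<T>::min().
template <typename T>
bool minValueNegationOverflows() {
  return negationOverflows(std::numeric_limits<T>::min());
}
```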