[llvm] [mlir] [mlir][arith] fix wrong floordivsi fold (#83079) (PR #83248)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 8 23:54:47 PST 2024


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/83248

>From 1fff7f9ebb8032ec7899c620342e5239ae9d74f0 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Wed, 28 Feb 2024 18:48:40 +0800
Subject: [PATCH] [mlir][arith] fix wrong floordivsi fold (#83079)

Fixes https://github.com/llvm/llvm-project/issues/83079
---
 llvm/include/llvm/ADT/APInt.h          |  1 +
 llvm/lib/Support/APInt.cpp             |  8 ++++++
 llvm/unittests/ADT/APIntTest.cpp       | 34 ++++++++++++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 36 +++-----------------------
 mlir/test/Transforms/canonicalize.mlir |  9 +++++++
 5 files changed, 55 insertions(+), 33 deletions(-)

diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 6c05367cecb1ea..9e6604a3dd8ea5 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -996,6 +996,7 @@ class [[nodiscard]] APInt {
   APInt sshl_ov(unsigned Amt, bool &Overflow) const;
   APInt ushl_ov(const APInt &Amt, bool &Overflow) const;
   APInt ushl_ov(unsigned Amt, bool &Overflow) const;
+  APInt sfloordiv_ov(const APInt &RHS, bool &Overflow) const;
 
   // Operations that saturate
   APInt sadd_sat(const APInt &RHS) const;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 05b1526da95ff7..97e0c5d09709e4 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -2022,6 +2022,14 @@ APInt APInt::ushl_ov(unsigned ShAmt, bool &Overflow) const {
   return *this << ShAmt;
 }
 
+APInt APInt::sfloordiv_ov(const APInt &RHS, bool &Overflow) const {
+  // Begin with the sdiv_ov quotient; that call is what sets Overflow.
+  APInt Quotient = sdiv_ov(RHS, Overflow);
+  // Step down by one when the division is inexact and operand signs differ.
+  bool RoundDown = Quotient * RHS != *this && isNegative() != RHS.isNegative();
+  return RoundDown ? Quotient - 1 : Quotient;
+}
+
 APInt APInt::sadd_sat(const APInt &RHS) const {
   bool Overflow;
   APInt Res = sadd_ov(RHS, Overflow);
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 2fe59f05ca753a..8c10028c3ef629 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -14,6 +14,7 @@
 #include "llvm/Support/Alignment.h"
 #include "gtest/gtest.h"
 #include <array>
+#include <limits>
 #include <optional>
 
 using namespace llvm;
@@ -2893,6 +2894,39 @@ TEST(APIntTest, smul_ov) {
       }
 }
 
+TEST(APIntTest, sfloordiv_ov) {
+  { // An inexact quotient with mixed signs rounds down: floor(-3 / 2) == -2.
+    APInt Dividend(32, -3, true);
+    APInt Divisor(32, 2, true);
+    bool Overflow = false;
+    APInt Quotient = Dividend.sfloordiv_ov(Divisor, Overflow);
+    EXPECT_FALSE(Overflow);
+    EXPECT_EQ(-2, Quotient.getSExtValue());
+  }
+  { // Dividing the minimum signed value by -1 must report overflow.
+    APInt Dividend(32, std::numeric_limits<int>::lowest(), true);
+    APInt Divisor(32, -1, true);
+    bool Overflow = false;
+    (void)Dividend.sfloordiv_ov(Divisor, Overflow);
+    EXPECT_TRUE(Overflow);
+  }
+  { // Same check per width; int8_t because plain char may be unsigned.
+    auto CheckOverflow = [](auto Arg) {
+      using IntTy = decltype(Arg);
+      APInt Dividend(8 * sizeof(IntTy), std::numeric_limits<IntTy>::lowest(),
+                     true);
+      APInt Divisor(8 * sizeof(IntTy), IntTy(-1), true);
+      bool Overflow = false;
+      (void)Dividend.sfloordiv_ov(Divisor, Overflow);
+      EXPECT_TRUE(Overflow);
+    };
+    CheckOverflow(int8_t());
+    CheckOverflow(int16_t());
+    CheckOverflow(int32_t());
+    CheckOverflow(int64_t());
+  }
+}
+
 TEST(APIntTest, SolveQuadraticEquationWrap) {
   // Verify that "Solution" is the first non-negative integer that solves
   // Ax^2 + Bx + C = "0 or overflow", i.e. that it is a correct solution
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..55126d3c5aa311 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -689,43 +689,18 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
     return getLhs();
 
   // Don't fold if it would overflow or if it requires a division by zero.
-  bool overflowOrDiv0 = false;
+  bool overflowOrDiv = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](APInt a, const APInt &b) {
-        if (overflowOrDiv0 || !b) {
-          overflowOrDiv0 = true;
-          return a;
-        }
-        if (!a)
-          return a;
-        // After this point we know that neither a or b are zero.
-        unsigned bits = a.getBitWidth();
-        APInt zero = APInt::getZero(bits);
-        bool aGtZero = a.sgt(zero);
-        bool bGtZero = b.sgt(zero);
-        if (aGtZero && bGtZero) {
-          // Both positive, return a / b.
-          return a.sdiv_ov(b, overflowOrDiv0);
-        }
-        if (!aGtZero && !bGtZero) {
-          // Both negative, return -a / -b.
-          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
-          APInt posB = zero.ssub_ov(b, overflowOrDiv0);
-          return posA.sdiv_ov(posB, overflowOrDiv0);
-        }
-        if (!aGtZero && bGtZero) {
-          // A is negative, b is positive, return - ceil(-a, b).
-          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
-          APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
-          return zero.ssub_ov(ceil, overflowOrDiv0);
-        }
-        // A is positive, b is negative, return - ceil(a, -b).
-        APInt posB = zero.ssub_ov(b, overflowOrDiv0);
-        APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
-        return zero.ssub_ov(ceil, overflowOrDiv0);
+        // Never divide by zero, and keep an earlier failure sticky.
+        if (overflowOrDiv || !b) {
+          overflowOrDiv = true;
+          return a;
+        }
+        return a.sfloordiv_ov(b, overflowOrDiv);
       });
 
-  return overflowOrDiv0 ? Attribute() : result;
+  return overflowOrDiv ? Attribute() : result;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 2cf86b50d432f6..d2c2c12d323892 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -989,6 +989,15 @@ func.func @tensor_arith.floordivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5x
   return %res : tensor<4x5xi32>
 }
 
+// CHECK-LABEL: func @arith.floordivsi_by_one_overflow
+func.func @arith.floordivsi_by_one_overflow() -> i64 {
+  %minus_one = arith.constant -1 : i64
+  %int_min = arith.constant -9223372036854775808 : i64
+  // CHECK: arith.floordivsi
+  %poison = arith.floordivsi %int_min, %minus_one : i64
+  return %poison : i64
+}
+
 // -----
 
 // CHECK-LABEL: func @arith.ceildivsi_by_one



More information about the llvm-commits mailing list