[Mlir-commits] [llvm] [mlir] [mlir][arith] fix wrong floordivsi fold (#83079) (PR #83248)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 9 23:14:16 PST 2024
https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/83248
>From 40faadb69f78a21f9eb1f082cdcb544d05c50f85 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Wed, 28 Feb 2024 18:48:40 +0800
Subject: [PATCH 1/4] [mlir][arith] fix wrong floordivsi fold (#83079)
Fixes https://github.com/llvm/llvm-project/issues/83079
---
llvm/include/llvm/ADT/APInt.h | 1 +
llvm/lib/Support/APInt.cpp | 8 ++++++
llvm/unittests/ADT/APIntTest.cpp | 34 ++++++++++++++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 36 +++-----------------------
mlir/test/Transforms/canonicalize.mlir | 9 +++++++
5 files changed, 55 insertions(+), 33 deletions(-)
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1fc3c7b2236a17..8a085f8b05ebbd 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -996,6 +996,7 @@ class [[nodiscard]] APInt {
APInt sshl_ov(unsigned Amt, bool &Overflow) const;
APInt ushl_ov(const APInt &Amt, bool &Overflow) const;
APInt ushl_ov(unsigned Amt, bool &Overflow) const;
+ APInt sfloordiv_ov(const APInt &RHS, bool &Overflow) const;
// Operations that saturate
APInt sadd_sat(const APInt &RHS) const;
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index e686b976523302..3bff2856cbdc54 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -2022,6 +2022,15 @@ APInt APInt::ushl_ov(unsigned ShAmt, bool &Overflow) const {
return *this << ShAmt;
}
+APInt APInt::sfloordiv_ov(const APInt &RHS, bool &Overflow) const {
+  APInt Quotient = sdiv_ov(RHS, Overflow);
+  // sdiv rounds toward zero; floor differs only for an inexact quotient whose
+  // operands have opposite signs, in which case round down by one.
+  if ((Quotient * RHS != *this) && (isNegative() != RHS.isNegative()))
+    return Quotient - 1;
+  return Quotient;
+}
+
APInt APInt::sadd_sat(const APInt &RHS) const {
bool Overflow;
APInt Res = sadd_ov(RHS, Overflow);
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 24324822356bf6..2af9afc8b30946 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -14,6 +14,7 @@
#include "llvm/Support/Alignment.h"
#include "gtest/gtest.h"
#include <array>
+#include <limits>
#include <optional>
using namespace llvm;
@@ -2928,6 +2929,42 @@ TEST(APIntTest, smul_ov) {
}
}
+TEST(APIntTest, sfloordiv_ov) {
+  {
+    APInt divisor(32, -3, true);
+    APInt dividend(32, 2, true);
+    bool Overflow = false;
+    auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
+    EXPECT_FALSE(Overflow);
+    EXPECT_EQ(-2, quotient.getSExtValue());
+  }
+  {
+    APInt divisor(32, std::numeric_limits<int>::lowest(), true);
+    APInt dividend(32, -1, true);
+    bool Overflow = false;
+    [[maybe_unused]] auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
+    EXPECT_TRUE(Overflow);
+  }
+  {
+    auto check_overflow_one = [](auto arg) {
+      using IntTy = decltype(arg);
+      APInt divisor(8 * sizeof(arg), std::numeric_limits<IntTy>::lowest(),
+                    true);
+      APInt dividend(8 * sizeof(arg), IntTy(-1), true);
+      bool Overflow = false;
+      [[maybe_unused]] auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
+      EXPECT_TRUE(Overflow);
+    };
+    auto check_overflow_all = [&](auto... args) {
+      (check_overflow_one(args), ...);
+    };
+    // Use explicitly signed types: plain `char` is unsigned on some targets
+    // (e.g. AArch64 Linux), where lowest() is 0 and 0 / -1 cannot overflow.
+    std::apply(check_overflow_all,
+               std::tuple<int8_t, int16_t, int32_t, int64_t>());
+  }
+}
+
TEST(APIntTest, SolveQuadraticEquationWrap) {
// Verify that "Solution" is the first non-negative integer that solves
// Ax^2 + Bx + C = "0 or overflow", i.e. that it is a correct solution
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 0f71c19c23b654..55126d3c5aa311 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -689,43 +689,13 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
return getLhs();
// Don't fold if it would overflow or if it requires a division by zero.
- bool overflowOrDiv0 = false;
+ bool overflowOrDiv = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](APInt a, const APInt &b) {
- if (overflowOrDiv0 || !b) {
- overflowOrDiv0 = true;
- return a;
- }
- if (!a)
- return a;
- // After this point we know that neither a or b are zero.
- unsigned bits = a.getBitWidth();
- APInt zero = APInt::getZero(bits);
- bool aGtZero = a.sgt(zero);
- bool bGtZero = b.sgt(zero);
- if (aGtZero && bGtZero) {
- // Both positive, return a / b.
- return a.sdiv_ov(b, overflowOrDiv0);
- }
- if (!aGtZero && !bGtZero) {
- // Both negative, return -a / -b.
- APInt posA = zero.ssub_ov(a, overflowOrDiv0);
- APInt posB = zero.ssub_ov(b, overflowOrDiv0);
- return posA.sdiv_ov(posB, overflowOrDiv0);
- }
- if (!aGtZero && bGtZero) {
- // A is negative, b is positive, return - ceil(-a, b).
- APInt posA = zero.ssub_ov(a, overflowOrDiv0);
- APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
- return zero.ssub_ov(ceil, overflowOrDiv0);
- }
- // A is positive, b is negative, return - ceil(a, -b).
- APInt posB = zero.ssub_ov(b, overflowOrDiv0);
- APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
- return zero.ssub_ov(ceil, overflowOrDiv0);
+ return a.sfloordiv_ov(b, overflowOrDiv);
});
- return overflowOrDiv0 ? Attribute() : result;
+ return overflowOrDiv ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 2cf86b50d432f6..d2c2c12d323892 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -989,6 +989,18 @@ func.func @tensor_arith.floordivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5x
return %res : tensor<4x5xi32>
}
+// -----
+
+// CHECK-LABEL: func @arith.floordivsi_by_one_overflow
+func.func @arith.floordivsi_by_one_overflow() -> i64 {
+  %neg_one = arith.constant -1 : i64
+  %min_int = arith.constant -9223372036854775808 : i64
+  // INT64_MIN floordiv -1 overflows, so the folder must leave the op alone.
+  // CHECK: arith.floordivsi
+  %res = arith.floordivsi %min_int, %neg_one : i64
+  return %res : i64
+}
+
// -----
// CHECK-LABEL: func @arith.ceildivsi_by_one
>From 52ff01d8f0b559215ee5ee6970f3fbb9016de76f Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 9 Mar 2024 21:40:47 +0800
Subject: [PATCH 2/4] fix floordivsi expansion logic
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 +
.../Dialect/Arith/Transforms/ExpandOps.cpp | 45 +++++++++-
mlir/test/Dialect/Arith/expand-ops.mlir | 84 ++++++++-----------
3 files changed, 81 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 55126d3c5aa311..fbc169bd6771b1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -692,6 +692,12 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
bool overflowOrDiv = false;
auto result = constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [&](APInt a, const APInt &b) {
+        // Keep the failure flag sticky across elements of a non-splat
+        // constant: sfloordiv_ov overwrites its flag on every call.
+        if (overflowOrDiv || b.isZero()) {
+          overflowOrDiv = true;
+          return a;
+        }
return a.sfloordiv_ov(b, overflowOrDiv);
});
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7f246daf99ff3c..7e8540f642fc83 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -110,10 +110,53 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
}
};
+/// Expands FloorDivSIOp (x, y) into a truncating division plus a correction:
+///   z = x / y
+///   result = (z * y != x && (x < 0) != (y < 0)) ? z - 1 : z
+/// where x is the dividend (LHS) and y the divisor (RHS). No intermediate
+/// value overflows unless x / y itself does (x == INT_MIN, y == -1):
+/// |z * y| <= |x|, and the -1 correction only applies to an inexact quotient,
+/// which implies |y| >= 2 and keeps z well away from INT_MIN.
+struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Type type = op.getType();
+    Value a = op.getLhs();
+    Value b = op.getRhs();
+
+    Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
+    Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
+    Value hasRemainder = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, a, product);
+    Value zero = createConst(loc, type, 0, rewriter);
+
+    Value aNeg =
+        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
+    Value bNeg =
+        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
+
+    Value signOpposite = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, aNeg, bNeg);
+    Value cond =
+        rewriter.create<arith::AndIOp>(loc, hasRemainder, signOpposite);
+
+    Value minusOne = createConst(loc, type, -1, rewriter);
+    Value quotientMinusOne =
+        rewriter.create<arith::SubIOp>(loc, quotient, minusOne);
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
+                                                 quotient);
+    return success();
+  }
+};
+
/// Expands FloorDivSIOp (n, m) into
/// 1) x = (m<0) ? 1 : -1
/// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
-struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
+struct AggressiveFloorDivSIOpConverter
+ : public OpRewritePattern<arith::FloorDivSIOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
PatternRewriter &rewriter) const final {
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 91f652e5a270e3..04420ae0a33fc6 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -66,23 +66,17 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
%res = arith.floordivsi %arg0, %arg1 : i32
return %res : i32
-// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
-// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK: [[MIN1:%.+]] = arith.constant -1 : i32
-// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : i32
-// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : i32
-// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : i32
-// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
-// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
-// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
-// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
+// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : i32
+// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : i32
+// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : i32
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32
+// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : i32
+// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : i32
+// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
+// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
+// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant -1 : i32
+// CHECK: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : i32
+// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : i32
}
// -----
@@ -93,23 +87,17 @@ func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
%res = arith.floordivsi %arg0, %arg1 : index
return %res : index
-// CHECK: [[ONE:%.+]] = arith.constant 1 : index
-// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK: [[MIN1:%.+]] = arith.constant -1 : index
-// CHECK: [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : index
-// CHECK: [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
-// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK: [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
-// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
-// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
-// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
-// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
+// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : index
+// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : index
+// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : index
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : index
+// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : index
+// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
+// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
+// CHECK: %[[NEG_ONE:.*]] = arith.constant -1 : index
+// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : index
+// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : index
}
// -----
@@ -121,23 +109,17 @@ func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
%res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
return %res : vector<4xi32>
-// CHECK: %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32>
-// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32>
-// CHECK: %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32>
-// CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK: %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32>
-// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32>
-// CHECK: %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32>
-// CHECK: %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32>
-// CHECK: %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
-// CHECK: %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
-// CHECK: %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
-// CHECK: %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK: %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1>
-// CHECK: %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1>
-// CHECK: %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1>
-// CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32>
+// CHECK: %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : vector<4xi32>
+// CHECK: %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : vector<4xi32>
+// CHECK: %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : vector<4xi32>
+// CHECK-DAG: %[[ZERO:.*]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : vector<4xi32>
+// CHECK: %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : vector<4xi32>
+// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : vector<4xi1>
+// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : vector<4xi1>
+// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant dense<-1> : vector<4xi32>
+// CHECK: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32>
+// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : vector<4xi1>, vector<4xi32>
}
// -----
>From 5dcd8892bbc298b0f1584acded0326eb9e047986 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sat, 9 Mar 2024 22:58:25 +0800
Subject: [PATCH 3/4] refine test
---
llvm/unittests/ADT/APIntTest.cpp | 12 +++--
.../Dialect/Arith/Transforms/ExpandOps.cpp | 51 -------------------
2 files changed, 8 insertions(+), 55 deletions(-)
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 2af9afc8b30946..5485978934ed07 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2930,6 +2930,7 @@ TEST(APIntTest, smul_ov) {
}
TEST(APIntTest, sfloordiv_ov) {
+ // test negative quotient
{
APInt divisor(32, -3, true);
APInt dividend(32, 2, true);
@@ -2938,13 +2939,16 @@ TEST(APIntTest, sfloordiv_ov) {
EXPECT_FALSE(Overflow);
EXPECT_EQ(-2, quotient.getSExtValue());
}
+ // test positive quotient
{
- APInt divisor(32, std::numeric_limits<int>::lowest(), true);
- APInt dividend(32, -1, true);
+ APInt divisor(32, 3, true);
+ APInt dividend(32, 2, true);
bool Overflow = false;
- [[maybe_unused]] auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
- EXPECT_TRUE(Overflow);
+ auto quotient = divisor.sfloordiv_ov(dividend, Overflow);
+ EXPECT_FALSE(Overflow);
+ EXPECT_EQ(1, quotient.getSExtValue());
}
+ // test overflow
{
auto check_overflow_one = [](auto arg) {
using IntTy = decltype(arg);
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7e8540f642fc83..1996374f0edae3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -152,57 +152,6 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
}
};
-/// Expands FloorDivSIOp (n, m) into
-/// 1) x = (m<0) ? 1 : -1
-/// 2) return (n*m<0) ? - ((-n+x) / m) -1 : n / m
-struct AggressiveFloorDivSIOpConverter
- : public OpRewritePattern<arith::FloorDivSIOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
- PatternRewriter &rewriter) const final {
- Location loc = op.getLoc();
- Type type = op.getType();
- Value a = op.getLhs();
- Value b = op.getRhs();
- Value plusOne = createConst(loc, type, 1, rewriter);
- Value zero = createConst(loc, type, 0, rewriter);
- Value minusOne = createConst(loc, type, -1, rewriter);
- // Compute x = (b<0) ? 1 : -1.
- Value compare =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
- Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
- // Compute negative res: -1 - ((x-a)/b).
- Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
- Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
- Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
- // Compute positive res: a/b.
- Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
- // Result is (a*b<0) ? negative result : positive result.
- // Note, we want to avoid using a*b because of possible overflow.
- // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
- // not particuliarly care if a*b<0 is true or false when b is zero
- // as this will result in an illegal divide. So `a*b<0` can be reformulated
- // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
- // We pick the first expression here.
- Value aNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
- Value aPos =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
- Value bNeg =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
- Value bPos =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
- Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
- Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
- Value compareRes =
- rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
- // Perform substitution and return success.
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
- posRes);
- return success();
- }
-};
-
template <typename OpTy, arith::CmpFPredicate pred>
struct MaximumMinimumFOpConverter : public OpRewritePattern<OpTy> {
public:
>From f231e71dc14837ca9b3444135276379a6c566fa8 Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Sun, 10 Mar 2024 15:18:41 +0800
Subject: [PATCH 4/4] add integration test
---
.../Dialect/Arith/Transforms/ExpandOps.cpp | 2 +-
mlir/test/Dialect/Arith/expand-ops.mlir | 6 +++---
.../Standard/CPU/test-ceil-floor-pos-neg.mlir | 21 +++++++++++++++++++
3 files changed, 25 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 1996374f0edae3..71e14a153cfda9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -144,7 +144,7 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
Value minusOne = createConst(loc, type, -1, rewriter);
Value quotientMinusOne =
- rewriter.create<arith::SubIOp>(loc, quotient, minusOne);
+ rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
quotient);
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 04420ae0a33fc6..6bed93e4c969db 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -75,7 +75,7 @@ func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant -1 : i32
-// CHECK: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : i32
+// CHECK: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : i32
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : i32
}
@@ -96,7 +96,7 @@ func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
// CHECK: %[[NEG_ONE:.*]] = arith.constant -1 : index
-// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : index
+// CHECK-DAG: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : index
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : index
}
@@ -118,7 +118,7 @@ func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<
// CHECK: %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : vector<4xi1>
// CHECK: %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : vector<4xi1>
// CHECK-DAG: %[[NEG_ONE:.*]] = arith.constant dense<-1> : vector<4xi32>
-// CHECK: %[[MINUS_ONE:.*]] = arith.subi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32>
+// CHECK: %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32>
// CHECK: %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : vector<4xi1>, vector<4xi32>
}
diff --git a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
index 39fbb67512c6f0..9f3672e56a48eb 100644
--- a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
+++ b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -2,6 +2,10 @@
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf,lower-affine,convert-scf-to-cf,memref-expand,arith-expand),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --check-prefix=SCHECK
func.func @transfer_read_2d(%A : memref<40xi32>, %base1: index) {
%i42 = arith.constant -42: i32
@@ -101,3 +105,20 @@ func.func @entry() {
// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
// CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 )
// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+
+// floordivsi(INT64_MIN + 1, -1) must produce INT64_MAX after arith-expand.
+
+func.func @non_inline_function() -> (i64, i64) {
+  %MIN_INT_PLUS_ONE = arith.constant -9223372036854775807 : i64
+  %NEG_ONE = arith.constant -1 : i64
+  return %MIN_INT_PLUS_ONE, %NEG_ONE : i64, i64
+}
+
+func.func @main() {
+  %0:2 = call @non_inline_function() : () -> (i64, i64)
+  %1 = arith.floordivsi %0#0, %0#1 : i64
+  vector.print %1 : i64
+  return
+}
+
+// SCHECK: 9223372036854775807
More information about the Mlir-commits
mailing list