[Mlir-commits] [mlir] [mlir][arith] fix wrong floordivsi fold (#83079) (PR #83248)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 15 04:11:12 PDT 2024


https://github.com/lipracer updated https://github.com/llvm/llvm-project/pull/83248

>From fcaa9e5007ac52e13144f53edece638cb586a32a Mon Sep 17 00:00:00 2001
From: lipracer <lipracer at gmail.com>
Date: Mon, 11 Mar 2024 13:12:16 +0800
Subject: [PATCH] [mlir][arith] fix wrong floordivsi fold (#83079)

1) Fix the incorrect expansion logic for floordivsi.
2) Fix the floordivsi fold, which didn't check the overflow state.

Fixes https://github.com/llvm/llvm-project/issues/83079
---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 36 ++------
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 58 ++++++-------
 mlir/test/Dialect/Arith/expand-ops.mlir       | 84 ++++++++-----------
 .../Standard/CPU/test-ceil-floor-pos-neg.mlir | 21 +++++
 mlir/test/Transforms/canonicalize.mlir        |  9 ++
 5 files changed, 93 insertions(+), 115 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 9f64a07f31e3af..2f32d9a26e7752 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -689,43 +689,17 @@ OpFoldResult arith::FloorDivSIOp::fold(FoldAdaptor adaptor) {
     return getLhs();
 
   // Don't fold if it would overflow or if it requires a division by zero.
-  bool overflowOrDiv0 = false;
+  bool overflowOrDiv = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](APInt a, const APInt &b) {
-        if (overflowOrDiv0 || !b) {
-          overflowOrDiv0 = true;
+        if (b.isZero()) {
+          overflowOrDiv = true;
           return a;
         }
-        if (!a)
-          return a;
-        // After this point we know that neither a or b are zero.
-        unsigned bits = a.getBitWidth();
-        APInt zero = APInt::getZero(bits);
-        bool aGtZero = a.sgt(zero);
-        bool bGtZero = b.sgt(zero);
-        if (aGtZero && bGtZero) {
-          // Both positive, return a / b.
-          return a.sdiv_ov(b, overflowOrDiv0);
-        }
-        if (!aGtZero && !bGtZero) {
-          // Both negative, return -a / -b.
-          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
-          APInt posB = zero.ssub_ov(b, overflowOrDiv0);
-          return posA.sdiv_ov(posB, overflowOrDiv0);
-        }
-        if (!aGtZero && bGtZero) {
-          // A is negative, b is positive, return - ceil(-a, b).
-          APInt posA = zero.ssub_ov(a, overflowOrDiv0);
-          APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
-          return zero.ssub_ov(ceil, overflowOrDiv0);
-        }
-        // A is positive, b is negative, return - ceil(a, -b).
-        APInt posB = zero.ssub_ov(b, overflowOrDiv0);
-        APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
-        return zero.ssub_ov(ceil, overflowOrDiv0);
+        return a.sfloordiv_ov(b, overflowOrDiv);
       });
 
-  return overflowOrDiv0 ? Attribute() : result;
+  return overflowOrDiv ? Attribute() : result;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 7f246daf99ff3c..71e14a153cfda9 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -110,9 +110,13 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
   }
 };
 
-/// Expands FloorDivSIOp (n, m) into
-///   1)  x = (m<0) ? 1 : -1
-///   2)  return (n*m<0) ? - ((-n+x) / m) -1 : n / m
+/// Expands FloorDivSIOp (x, y) into
+/// z = x / y
+/// if (z * y != x && (x < 0) != (y < 0)) {
+///   return  z - 1;
+/// } else {
+///   return z;
+/// }
 struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
@@ -121,41 +125,29 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
     Type type = op.getType();
     Value a = op.getLhs();
     Value b = op.getRhs();
-    Value plusOne = createConst(loc, type, 1, rewriter);
+
+    Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
+    Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
+    Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, a, product);
     Value zero = createConst(loc, type, 0, rewriter);
-    Value minusOne = createConst(loc, type, -1, rewriter);
-    // Compute x = (b<0) ? 1 : -1.
-    Value compare =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
-    Value x = rewriter.create<arith::SelectOp>(loc, compare, plusOne, minusOne);
-    // Compute negative res: -1 - ((x-a)/b).
-    Value xMinusA = rewriter.create<arith::SubIOp>(loc, x, a);
-    Value xMinusADivB = rewriter.create<arith::DivSIOp>(loc, xMinusA, b);
-    Value negRes = rewriter.create<arith::SubIOp>(loc, minusOne, xMinusADivB);
-    // Compute positive res: a/b.
-    Value posRes = rewriter.create<arith::DivSIOp>(loc, a, b);
-    // Result is (a*b<0) ? negative result : positive result.
-    // Note, we want to avoid using a*b because of possible overflow.
-    // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
-    // not particuliarly care if a*b<0 is true or false when b is zero
-    // as this will result in an illegal divide. So `a*b<0` can be reformulated
-    // as `(a>0 && b<0) || (a>0 && b<0)' or `(a>0 && b<0) || (a>0 && b<=0)'.
-    // We pick the first expression here.
+
     Value aNeg =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
-    Value aPos =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
     Value bNeg =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
-    Value bPos =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
-    Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bPos);
-    Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
-    Value compareRes =
-        rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
-    // Perform substitution and return success.
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, negRes,
-                                                 posRes);
+
+    Value signOpposite = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, aNeg, bNeg);
+    Value cond =
+        rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite);
+
+    Value minusOne = createConst(loc, type, -1, rewriter);
+    Value quotientMinusOne =
+        rewriter.create<arith::AddIOp>(loc, quotient, minusOne);
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne,
+                                                 quotient);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 91f652e5a270e3..6bed93e4c969db 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -66,23 +66,17 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
 func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
   %res = arith.floordivsi %arg0, %arg1 : i32
   return %res : i32
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK:           [[MIN1:%.+]] = arith.constant -1 : i32
-// CHECK:           [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : i32
-// CHECK:           [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : i32
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK:           [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : i32
-// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
+// CHECK:   %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : i32
+// CHECK:   %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : i32
+// CHECK:   %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : i32
+// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : i32
+// CHECK:   %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : i32
+// CHECK:   %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : i32
+// CHECK:   %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
+// CHECK:   %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
+// CHECK-DAG:   %[[NEG_ONE:.*]] = arith.constant -1 : i32
+// CHECK:   %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : i32
+// CHECK:   %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : i32
 }
 
 // -----
@@ -93,23 +87,17 @@ func.func @floordivi(%arg0: i32, %arg1: i32) -> (i32) {
 func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
   %res = arith.floordivsi %arg0, %arg1 : index
   return %res : index
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
-// CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK:           [[MIN1:%.+]] = arith.constant -1 : index
-// CHECK:           [[CMP1:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[ONE]], [[MIN1]] : index
-// CHECK:           [[TRUE1:%.+]] = arith.subi [[X]], [[ARG0]] : index
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK:           [[TRUE3:%.+]] = arith.subi [[MIN1]], [[TRUE2]] : index
-// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MPOS]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MNEG]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
+// CHECK:   %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : index
+// CHECK:   %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : index
+// CHECK:   %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : index
+// CHECK-DAG:   %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK:   %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : index
+// CHECK:   %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : index
+// CHECK:   %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : i1
+// CHECK:   %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : i1
+// CHECK:   %[[NEG_ONE:.*]] = arith.constant -1 : index
+// CHECK-DAG:   %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : index
+// CHECK:   %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : index
 }
 
 // -----
@@ -121,23 +109,17 @@ func.func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
 func.func @floordivi_vec(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>) {
   %res = arith.floordivsi %arg0, %arg1 : vector<4xi32>
   return %res : vector<4xi32>
-// CHECK:           %[[VAL_2:.*]] = arith.constant dense<1> : vector<4xi32>
-// CHECK:           %[[VAL_3:.*]] = arith.constant dense<0> : vector<4xi32>
-// CHECK:           %[[VAL_4:.*]] = arith.constant dense<-1> : vector<4xi32>
-// CHECK:           %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_6:.*]] = arith.select %[[VAL_5]], %[[VAL_2]], %[[VAL_4]] : vector<4xi1>, vector<4xi32>
-// CHECK:           %[[VAL_7:.*]] = arith.subi %[[VAL_6]], %[[VAL_0]] : vector<4xi32>
-// CHECK:           %[[VAL_8:.*]] = arith.divsi %[[VAL_7]], %[[VAL_1]] : vector<4xi32>
-// CHECK:           %[[VAL_9:.*]] = arith.subi %[[VAL_4]], %[[VAL_8]] : vector<4xi32>
-// CHECK:           %[[VAL_10:.*]] = arith.divsi %[[VAL_0]], %[[VAL_1]] : vector<4xi32>
-// CHECK:           %[[VAL_11:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_12:.*]] = arith.cmpi sgt, %[[VAL_0]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_13:.*]] = arith.cmpi slt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_14:.*]] = arith.cmpi sgt, %[[VAL_1]], %[[VAL_3]] : vector<4xi32>
-// CHECK:           %[[VAL_15:.*]] = arith.andi %[[VAL_11]], %[[VAL_14]] : vector<4xi1>
-// CHECK:           %[[VAL_16:.*]] = arith.andi %[[VAL_12]], %[[VAL_13]] : vector<4xi1>
-// CHECK:           %[[VAL_17:.*]] = arith.ori %[[VAL_15]], %[[VAL_16]] : vector<4xi1>
-// CHECK:           %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_9]], %[[VAL_10]] : vector<4xi1>, vector<4xi32>
+// CHECK:   %[[QUOTIENT:.*]] = arith.divsi %arg0, %arg1 : vector<4xi32>
+// CHECK:   %[[PRODUCT:.*]] = arith.muli %[[QUOTIENT]], %arg1 : vector<4xi32>
+// CHECK:   %[[NOT_EQ_PRODUCT:.*]] = arith.cmpi ne, %arg0, %[[PRODUCT]] : vector<4xi32>
+// CHECK-DAG:   %[[ZERO:.*]] = arith.constant dense<0> : vector<4xi32>
+// CHECK:   %[[NEG_DIVISOR:.*]] = arith.cmpi slt, %arg0, %[[ZERO]] : vector<4xi32>
+// CHECK:   %[[NEG_DIVIDEND:.*]] = arith.cmpi slt, %arg1, %[[ZERO]] : vector<4xi32>
+// CHECK:   %[[OPPOSITE_SIGN:.*]] = arith.cmpi ne, %[[NEG_DIVISOR]], %[[NEG_DIVIDEND]] : vector<4xi1>
+// CHECK:   %[[CONDITION:.*]] = arith.andi %[[NOT_EQ_PRODUCT]], %[[OPPOSITE_SIGN]] : vector<4xi1>
+// CHECK-DAG:   %[[NEG_ONE:.*]] = arith.constant dense<-1> : vector<4xi32>
+// CHECK:   %[[MINUS_ONE:.*]] = arith.addi %[[QUOTIENT]], %[[NEG_ONE]] : vector<4xi32>
+// CHECK:   %[[RES:.*]] = arith.select %[[CONDITION]], %[[MINUS_ONE]], %[[QUOTIENT]] : vector<4xi1>, vector<4xi32>
 }
 
 // -----
diff --git a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
index 39fbb67512c6f0..9f3672e56a48eb 100644
--- a/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
+++ b/mlir/test/Integration/Dialect/Standard/CPU/test-ceil-floor-pos-neg.mlir
@@ -2,6 +2,10 @@
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void  \
 // RUN:   -shared-libs=%mlir_c_runner_utils | \
 // RUN: FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-scf,lower-affine,convert-scf-to-cf,memref-expand,arith-expand),convert-vector-to-llvm,finalize-memref-to-llvm,convert-func-to-llvm,reconcile-unrealized-casts)" | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void  \
+// RUN:   -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s --check-prefix=SCHECK
 
 func.func @transfer_read_2d(%A : memref<40xi32>, %base1: index) {
   %i42 = arith.constant -42: i32
@@ -101,3 +105,20 @@ func.func @entry() {
 // CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
 // CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 )
 // CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
+
+// -----
+
+func.func @non_inline_function() -> (i64, i64) {
+  %MIN_INT_MINUS_ONE = arith.constant -9223372036854775807 : i64
+  %NEG_ONE = arith.constant -1 : i64
+  return %MIN_INT_MINUS_ONE, %NEG_ONE : i64, i64
+}
+
+func.func @main() {
+  %0:2 = call @non_inline_function() : () -> (i64, i64)
+  %1 = arith.floordivsi %0#0, %0#1 : i64
+  vector.print %1 : i64
+  return
+}
+
+// SCHECK: 9223372036854775807
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 2cf86b50d432f6..d2c2c12d323892 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -989,6 +989,15 @@ func.func @tensor_arith.floordivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5x
   return %res : tensor<4x5xi32>
 }
 
+// CHECK-LABEL: func @arith.floordivsi_by_one_overflow
+func.func @arith.floordivsi_by_one_overflow() -> i64 {
+  %neg_one = arith.constant -1 : i64
+  %min_int = arith.constant -9223372036854775808 : i64
+  // CHECK: arith.floordivsi
+  %poision = arith.floordivsi %min_int, %neg_one : i64
+  return %poision : i64
+}
+
 // -----
 
 // CHECK-LABEL: func @arith.ceildivsi_by_one



More information about the Mlir-commits mailing list