[Mlir-commits] [mlir] [mlir][arith] fix ceildivsi lowering (PR #106696)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 5 09:35:03 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Joachim Giron (jgiron42)

<details>
<summary>Changes</summary>

This commit fixes the overflow in the case of
ceildivsi( <signed_type_min>, <positive_integer> )

fixes #<!-- -->106519 


---
Full diff: https://github.com/llvm/llvm-project/pull/106696.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+3-5) 
- (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+4-8) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 54be644a710113..a37a18135386e1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -60,7 +60,7 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
 
 /// Expands CeilDivSIOp (n, m) into
 ///   1) x = (m > 0) ? -1 : 1
-///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+///   2) (n*m>0) ? ((n+x) / m) + 1 : n / m
 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -80,10 +80,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
     Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
     Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
     Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
-    // Compute negative res: - ((-a)/b).
-    Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
-    Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
-    Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
+    // Compute negative res: a/b.
+    Value negRes = rewriter.create<arith::DivSIOp>(loc, a, b);
     // Result is (a*b>0) ? pos result : neg result.
     // Note, we want to avoid using a*b because of possible overflow.
     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 174eb468cc0041..3ba89ab740ef18 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -15,9 +15,7 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
 // CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
 // CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
 // CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
+// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
 // CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
 // CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
 // CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
@@ -25,7 +23,7 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
 // CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
 // CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
 // CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
 }
 
 // -----
@@ -45,9 +43,7 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
 // CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
 // CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
 // CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
+// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
 // CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
 // CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
 // CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
@@ -55,7 +51,7 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
 // CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
 // CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
 // CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/106696


More information about the Mlir-commits mailing list