[Mlir-commits] [mlir] [mlir] [arith] Fix ceildivsi lowering in arith-expand (PR #133774)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 31 11:44:56 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-arith

Author: Fehr Mathieu (math-fehr)

<details>
<summary>Changes</summary>

This fixes the lowering of `arith.ceildivsi` in the arith-expand pass, which was previously incorrect. The new version is based on the lowering of `arith.floordivsi`, and does not introduce new undefined behavior or poison during the lowering. It also emits a single division plus a multiplication, where the previous lowering emitted two divisions.

The previous lowering of `ceildivsi(n, m)` was the following:
```
x = (m > 0) ? -1 : 1
(n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
```

This caused two problems:
* In the case where `n` is INT_MIN and `m` is positive, the result would be poison instead of an actual value
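* In the case where `n` is INT_MAX and `m` is `-1`, the lowered code would trigger undefined behavior, while the original `arith.ceildivsi` operation wouldn't. This is because `x` is `1` when `m` is negative, so `n+x` would wrap around to `INT_MIN` (`INT_MAX + 1`). The `(n+x) / m` division is emitted unconditionally, before the final select, so it would compute `INT_MIN / -1`, which overflows and triggers UB.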

---
Full diff: https://github.com/llvm/llvm-project/pull/133774.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+25-35) 
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+14-22) 
- (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+21-36) 


``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 54be644a71011..2d627e523cde5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -58,9 +58,13 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
   }
 };
 
-/// Expands CeilDivSIOp (n, m) into
-///   1) x = (m > 0) ? -1 : 1
-///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+/// Expands CeilDivSIOp (a, b) into
+/// z = a / b
+/// if (z * b != a && (a < 0) == (b < 0)) {
+///   return z + 1;
+/// } else {
+///   return z;
+/// }
 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -69,43 +73,29 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
     Type type = op.getType();
     Value a = op.getLhs();
     Value b = op.getRhs();
-    Value plusOne = createConst(loc, type, 1, rewriter);
+
     Value zero = createConst(loc, type, 0, rewriter);
-    Value minusOne = createConst(loc, type, -1, rewriter);
-    // Compute x = (b>0) ? -1 : 1.
-    Value compare =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
-    Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
-    // Compute positive res: 1 + ((x+a)/b).
-    Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
-    Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
-    Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
-    // Compute negative res: - ((-a)/b).
-    Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
-    Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
-    Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
-    // Result is (a*b>0) ? pos result : neg result.
-    // Note, we want to avoid using a*b because of possible overflow.
-    // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
-    // not particuliarly care if a*b<0 is true or false when b is zero
-    // as this will result in an illegal divide. So `a*b<0` can be reformulated
-    // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
-    // We pick the first expression here.
+    Value one = createConst(loc, type, 1, rewriter);
+
+    Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
+    Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
+    Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, a, product);
+
     Value aNeg =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
-    Value aPos =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
     Value bNeg =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
-    Value bPos =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
-    Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
-    Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
-    Value compareRes =
-        rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
-    // Perform substitution and return success.
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
-                                                 negRes);
+
+    Value signEqual = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, aNeg, bNeg);
+    Value cond =
+        rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
+
+    Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
+                                                 quotient);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7daf4ef8717bc..e0d974ea74041 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -600,28 +600,20 @@ func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>
 // -----
 
 // CHECK-LABEL: @ceildivsi
-// CHECK-SAME: %[[ARG0:.*]]: i64) -> i64
-func.func @ceildivsi(%arg0 : i64) -> i64 {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(1 : i64) : i64
-  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(-1 : i64) : i64
-  // CHECK: %[[CMP0:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[CST2]], %[[CST0]] : i1, i64
-  // CHECK: %[[ADD0:.*]] = llvm.add %[[SEL0]], %[[ARG0]] : i64
-  // CHECK: %[[DIV0:.*]] = llvm.sdiv %[[ADD0]], %[[ARG0]] : i64
-  // CHECK: %[[ADD1:.*]] = llvm.add %[[DIV0]], %[[CST0]] : i64
-  // CHECK: %[[SUB0:.*]] = llvm.sub %[[CST1]], %[[ARG0]] : i64
-  // CHECK: %[[DIV1:.*]] = llvm.sdiv %[[SUB0]], %[[ARG0]] : i64
-  // CHECK: %[[SUB1:.*]] = llvm.sub %[[CST1]], %[[DIV1]] : i64
-  // CHECK: %[[CMP1:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[CMP2:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[CMP3:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[CMP4:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[AND0:.*]] = llvm.and %[[CMP1]], %[[CMP3]] : i1
-  // CHECK: %[[AND1:.*]] = llvm.and %[[CMP2]], %[[CMP4]] : i1
-  // CHECK: %[[OR:.*]] = llvm.or %[[AND0]], %[[AND1]] : i1
-  // CHECK: %[[SEL1:.*]] = llvm.select %[[OR]], %[[ADD1]], %[[SUB1]] : i1, i64
-  %0 = arith.ceildivsi %arg0, %arg0 : i64
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) -> i64
+func.func @ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[ONE:.+]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[DIV:.+]] = llvm.sdiv %[[ARG0]], %[[ARG1]] : i64
+  // CHECK: %[[MUL:.+]] = llvm.mul %[[DIV]], %[[ARG1]] : i64
+  // CHECK: %[[NEXACT:.+]] = llvm.icmp "ne" %[[ARG0]], %[[MUL]] : i64
+  // CHECK: %[[NNEG:.+]] = llvm.icmp "slt" %[[ARG0]], %[[ZERO]] : i64
+  // CHECK: %[[MNEG:.+]] = llvm.icmp "slt" %[[ARG1]], %[[ZERO]] : i64
+  // CHECK: %[[SAMESIGN:.+]] = llvm.icmp "eq" %[[NNEG]], %[[MNEG]] : i1
+  // CHECK: %[[SHOULDROUND:.+]] = llvm.and %[[NEXACT]], %[[SAMESIGN]] : i1
+  // CHECK: %[[CEIL:.+]] = llvm.add %[[DIV]], %[[ONE]] : i64
+  // CHECK: %[[RES:.+]] = llvm.select %[[SHOULDROUND]], %[[CEIL]], %[[DIV]] : i1, i64
+  %0 = arith.ceildivsi %arg0, %arg1 : i64
   return %0: i64
 }
 
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 174eb468cc004..bdf022642b717 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -7,25 +7,17 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
   %res = arith.ceildivsi %arg0, %arg1 : i32
   return %res : i32
 
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
 // CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK:           [[MINONE:%.+]] = arith.constant -1 : i32
-// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32
-// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
+// CHECK:           [[DIV:%.+]] = arith.divsi %arg0, %arg1 : i32
+// CHECK:           [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : i32
+// CHECK:           [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : i32
+// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : i32
+// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : i32
+// CHECK:           [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK:           [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK:           [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
+// CHECK:           [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : i32
 }
 
 // -----
@@ -37,25 +29,18 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
   %res = arith.ceildivsi %arg0, %arg1 : index
   return %res : index
 
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
 // CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK:           [[MINONE:%.+]] = arith.constant -1 : index
-// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index
-// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
+// CHECK:           [[DIV:%.+]] = arith.divsi %arg0, %arg1 : index
+// CHECK:           [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : index
+// CHECK:           [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : index
+// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : index
+// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : index
+// CHECK:           [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK:           [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK:           [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : index
+// CHECK:           [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : index
+
 }
 
 // -----

``````````

</details>


https://github.com/llvm/llvm-project/pull/133774


More information about the Mlir-commits mailing list