[Mlir-commits] [mlir] [mlir] [arith] Fix ceildivsi lowering in arith-expand (PR #133774)

Fehr Mathieu llvmlistbot at llvm.org
Mon Mar 31 11:44:36 PDT 2025


https://github.com/math-fehr created https://github.com/llvm/llvm-project/pull/133774

This fixes the current lowering of `arith.ceildivsi` in the arith-expand pass, which was previously incorrect. The new version is based on the lowering of `arith.floordivsi`, and will not introduce new undefined behavior or poison during the lowering. It also replaces one division with a multiplication.

The previous lowering of `ceildivsi(n, m)` was the following:
```
x = (m > 0) ? -1 : 1
(n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
```

This caused two problems:
* In the case where `n` is INT_MIN and `m` is positive, the result would be poison instead of an actual value.
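* In the case where `n` is INT_MAX and `m` is `-1`, this would trigger undefined behavior, while the original code wouldn't. Since `m` is negative, `x` is `1`, so `n+x` wraps to `INT_MIN` (`INT_MAX + 1`). Both operands of the final select are computed unconditionally, so the `(n+x) / m` division (i.e. `INT_MIN / -1`) is executed even though its result is not selected, and it overflows and triggers UB.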

>From f829156d28e7b42e16910ee427cd6b25398789c5 Mon Sep 17 00:00:00 2001
From: Mathieu Fehr <mathieu.fehr at gmail.com>
Date: Mon, 31 Mar 2025 19:56:34 +0100
Subject: [PATCH] [mlir] Fix ceildivsi lowering in arith-expand

This fixes the current lowering of arith.ceildivsi in the arith-expand
pass, which was previously incorrect. The new version is based on the
lowering of arith.floordivsi, and will not introduce new undefined
behavior or poison during the lowering. It also performs one fewer division.
---
 .../Dialect/Arith/Transforms/ExpandOps.cpp    | 60 ++++++++-----------
 .../Conversion/ArithToLLVM/arith-to-llvm.mlir | 36 +++++------
 mlir/test/Dialect/Arith/expand-ops.mlir       | 57 +++++++-----------
 3 files changed, 60 insertions(+), 93 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 54be644a71011..2d627e523cde5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -58,9 +58,13 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
   }
 };
 
-/// Expands CeilDivSIOp (n, m) into
-///   1) x = (m > 0) ? -1 : 1
-///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+/// Expands CeilDivSIOp (a, b) into
+/// z = a / b
+/// if (z * b != a && (a < 0) == (b < 0)) {
+///   return z + 1;
+/// } else {
+///   return z;
+/// }
 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -69,43 +73,29 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
     Type type = op.getType();
     Value a = op.getLhs();
     Value b = op.getRhs();
-    Value plusOne = createConst(loc, type, 1, rewriter);
+
     Value zero = createConst(loc, type, 0, rewriter);
-    Value minusOne = createConst(loc, type, -1, rewriter);
-    // Compute x = (b>0) ? -1 : 1.
-    Value compare =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
-    Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
-    // Compute positive res: 1 + ((x+a)/b).
-    Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
-    Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
-    Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
-    // Compute negative res: - ((-a)/b).
-    Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
-    Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
-    Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
-    // Result is (a*b>0) ? pos result : neg result.
-    // Note, we want to avoid using a*b because of possible overflow.
-    // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
-    // not particuliarly care if a*b<0 is true or false when b is zero
-    // as this will result in an illegal divide. So `a*b<0` can be reformulated
-    // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
-    // We pick the first expression here.
+    Value one = createConst(loc, type, 1, rewriter);
+
+    Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
+    Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
+    Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne, a, product);
+
     Value aNeg =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
-    Value aPos =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
     Value bNeg =
         rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
-    Value bPos =
-        rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
-    Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
-    Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
-    Value compareRes =
-        rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
-    // Perform substitution and return success.
-    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
-                                                 negRes);
+
+    Value signEqual = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::eq, aNeg, bNeg);
+    Value cond =
+        rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
+
+    Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
+                                                 quotient);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7daf4ef8717bc..e0d974ea74041 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -600,28 +600,20 @@ func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>
 // -----
 
 // CHECK-LABEL: @ceildivsi
-// CHECK-SAME: %[[ARG0:.*]]: i64) -> i64
-func.func @ceildivsi(%arg0 : i64) -> i64 {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(1 : i64) : i64
-  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(-1 : i64) : i64
-  // CHECK: %[[CMP0:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[CST2]], %[[CST0]] : i1, i64
-  // CHECK: %[[ADD0:.*]] = llvm.add %[[SEL0]], %[[ARG0]] : i64
-  // CHECK: %[[DIV0:.*]] = llvm.sdiv %[[ADD0]], %[[ARG0]] : i64
-  // CHECK: %[[ADD1:.*]] = llvm.add %[[DIV0]], %[[CST0]] : i64
-  // CHECK: %[[SUB0:.*]] = llvm.sub %[[CST1]], %[[ARG0]] : i64
-  // CHECK: %[[DIV1:.*]] = llvm.sdiv %[[SUB0]], %[[ARG0]] : i64
-  // CHECK: %[[SUB1:.*]] = llvm.sub %[[CST1]], %[[DIV1]] : i64
-  // CHECK: %[[CMP1:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[CMP2:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[CMP3:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[CMP4:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
-  // CHECK: %[[AND0:.*]] = llvm.and %[[CMP1]], %[[CMP3]] : i1
-  // CHECK: %[[AND1:.*]] = llvm.and %[[CMP2]], %[[CMP4]] : i1
-  // CHECK: %[[OR:.*]] = llvm.or %[[AND0]], %[[AND1]] : i1
-  // CHECK: %[[SEL1:.*]] = llvm.select %[[OR]], %[[ADD1]], %[[SUB1]] : i1, i64
-  %0 = arith.ceildivsi %arg0, %arg0 : i64
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) -> i64
+func.func @ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 {
+  // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[ONE:.+]] = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[DIV:.+]] = llvm.sdiv %[[ARG0]], %[[ARG1]] : i64
+  // CHECK: %[[MUL:.+]] = llvm.mul %[[DIV]], %[[ARG1]] : i64
+  // CHECK: %[[NEXACT:.+]] = llvm.icmp "ne" %[[ARG0]], %[[MUL]] : i64
+  // CHECK: %[[NNEG:.+]] = llvm.icmp "slt" %[[ARG0]], %[[ZERO]] : i64
+  // CHECK: %[[MNEG:.+]] = llvm.icmp "slt" %[[ARG1]], %[[ZERO]] : i64
+  // CHECK: %[[SAMESIGN:.+]] = llvm.icmp "eq" %[[NNEG]], %[[MNEG]] : i1
+  // CHECK: %[[SHOULDROUND:.+]] = llvm.and %[[NEXACT]], %[[SAMESIGN]] : i1
+  // CHECK: %[[CEIL:.+]] = llvm.add %[[DIV]], %[[ONE]] : i64
+  // CHECK: %[[RES:.+]] = llvm.select %[[SHOULDROUND]], %[[CEIL]], %[[DIV]] : i1, i64
+  %0 = arith.ceildivsi %arg0, %arg1 : i64
   return %0: i64
 }
 
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 174eb468cc004..bdf022642b717 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -7,25 +7,17 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
   %res = arith.ceildivsi %arg0, %arg1 : i32
   return %res : i32
 
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
 // CHECK:           [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK:           [[MINONE:%.+]] = arith.constant -1 : i32
-// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32
-// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : i32
+// CHECK:           [[DIV:%.+]] = arith.divsi %arg0, %arg1 : i32
+// CHECK:           [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : i32
+// CHECK:           [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : i32
+// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : i32
+// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : i32
+// CHECK:           [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK:           [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK:           [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
+// CHECK:           [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : i32
 }
 
 // -----
@@ -37,25 +29,18 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
   %res = arith.ceildivsi %arg0, %arg1 : index
   return %res : index
 
-// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
 // CHECK:           [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK:           [[MINONE:%.+]] = arith.constant -1 : index
-// CHECK:           [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index
-// CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
-// CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
-// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+// CHECK:           [[ONE:%.+]] = arith.constant 1 : index
+// CHECK:           [[DIV:%.+]] = arith.divsi %arg0, %arg1 : index
+// CHECK:           [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : index
+// CHECK:           [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : index
+// CHECK:           [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : index
+// CHECK:           [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : index
+// CHECK:           [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK:           [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK:           [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : index
+// CHECK:           [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : index
+
 }
 
 // -----



More information about the Mlir-commits mailing list