[Mlir-commits] [mlir] [mlir] [arith] Fix ceildivsi lowering in arith-expand (PR #133774)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 31 11:44:56 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Fehr Mathieu (math-fehr)
<details>
<summary>Changes</summary>
This fixes the lowering of `arith.ceildivsi` in the arith-expand pass, which was incorrect. The new lowering is modeled on the lowering of `arith.floordivsi` and does not introduce new undefined behavior or poison. It also replaces one of the two divisions in the previous lowering with a multiplication.
The previous lowering of `ceildivsi(n, m)` was the following:
```
x = (m > 0) ? -1 : 1
(n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
```
This caused two problems:
* In the case where `n` is INT_MIN and `m` is positive, the result would be poison instead of an actual value
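* In the case where `n` is INT_MAX and `m` is `-1`, the lowering would trigger undefined behavior even though the original operation does not. Here `x` is `1`, so `n+x` wraps around to `INT_MIN` (`INT_MAX + 1`), and the `(n+x) / m` division becomes `INT_MIN / -1`, which overflows and triggers UB. Both arms of the select are computed unconditionally, so this division is executed even though its result would not be selected.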
---
Full diff: https://github.com/llvm/llvm-project/pull/133774.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp (+25-35)
- (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+14-22)
- (modified) mlir/test/Dialect/Arith/expand-ops.mlir (+21-36)
``````````diff
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 54be644a71011..2d627e523cde5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -58,9 +58,13 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
}
};
-/// Expands CeilDivSIOp (n, m) into
-/// 1) x = (m > 0) ? -1 : 1
-/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+/// Expands CeilDivSIOp (a, b) into
+/// z = a / b
+/// if (z * b != a && (a < 0) == (b < 0)) {
+/// return z + 1;
+/// } else {
+/// return z;
+/// }
struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -69,43 +73,29 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
Type type = op.getType();
Value a = op.getLhs();
Value b = op.getRhs();
- Value plusOne = createConst(loc, type, 1, rewriter);
+
Value zero = createConst(loc, type, 0, rewriter);
- Value minusOne = createConst(loc, type, -1, rewriter);
- // Compute x = (b>0) ? -1 : 1.
- Value compare =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
- Value x = rewriter.create<arith::SelectOp>(loc, compare, minusOne, plusOne);
- // Compute positive res: 1 + ((x+a)/b).
- Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
- Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
- Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
- // Compute negative res: - ((-a)/b).
- Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
- Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
- Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
- // Result is (a*b>0) ? pos result : neg result.
- // Note, we want to avoid using a*b because of possible overflow.
- // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
- // not particuliarly care if a*b<0 is true or false when b is zero
- // as this will result in an illegal divide. So `a*b<0` can be reformulated
- // as `(a<0 && b<0) || (a>0 && b>0)' or `(a<0 && b<0) || (a>0 && b>=0)'.
- // We pick the first expression here.
+ Value one = createConst(loc, type, 1, rewriter);
+
+ Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b);
+ Value product = rewriter.create<arith::MulIOp>(loc, quotient, b);
+ Value notEqualDivisor = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ne, a, product);
+
Value aNeg =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero);
- Value aPos =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, a, zero);
Value bNeg =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
- Value bPos =
- rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
- Value firstTerm = rewriter.create<arith::AndIOp>(loc, aNeg, bNeg);
- Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
- Value compareRes =
- rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
- // Perform substitution and return success.
- rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compareRes, posRes,
- negRes);
+
+ Value signEqual = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, aNeg, bNeg);
+ Value cond =
+ rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual);
+
+ Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
+
+ rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne,
+ quotient);
return success();
}
};
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 7daf4ef8717bc..e0d974ea74041 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -600,28 +600,20 @@ func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>
// -----
// CHECK-LABEL: @ceildivsi
-// CHECK-SAME: %[[ARG0:.*]]: i64) -> i64
-func.func @ceildivsi(%arg0 : i64) -> i64 {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(1 : i64) : i64
- // CHECK: %[[CST1:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST2:.*]] = llvm.mlir.constant(-1 : i64) : i64
- // CHECK: %[[CMP0:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[SEL0:.*]] = llvm.select %[[CMP0]], %[[CST2]], %[[CST0]] : i1, i64
- // CHECK: %[[ADD0:.*]] = llvm.add %[[SEL0]], %[[ARG0]] : i64
- // CHECK: %[[DIV0:.*]] = llvm.sdiv %[[ADD0]], %[[ARG0]] : i64
- // CHECK: %[[ADD1:.*]] = llvm.add %[[DIV0]], %[[CST0]] : i64
- // CHECK: %[[SUB0:.*]] = llvm.sub %[[CST1]], %[[ARG0]] : i64
- // CHECK: %[[DIV1:.*]] = llvm.sdiv %[[SUB0]], %[[ARG0]] : i64
- // CHECK: %[[SUB1:.*]] = llvm.sub %[[CST1]], %[[DIV1]] : i64
- // CHECK: %[[CMP1:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[CMP2:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[CMP3:.*]] = llvm.icmp "slt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[CMP4:.*]] = llvm.icmp "sgt" %[[ARG0]], %[[CST1]] : i64
- // CHECK: %[[AND0:.*]] = llvm.and %[[CMP1]], %[[CMP3]] : i1
- // CHECK: %[[AND1:.*]] = llvm.and %[[CMP2]], %[[CMP4]] : i1
- // CHECK: %[[OR:.*]] = llvm.or %[[AND0]], %[[AND1]] : i1
- // CHECK: %[[SEL1:.*]] = llvm.select %[[OR]], %[[ADD1]], %[[SUB1]] : i1, i64
- %0 = arith.ceildivsi %arg0, %arg0 : i64
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64) -> i64
+func.func @ceildivsi(%arg0 : i64, %arg1 : i64) -> i64 {
+ // CHECK: %[[ZERO:.+]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[ONE:.+]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[DIV:.+]] = llvm.sdiv %[[ARG0]], %[[ARG1]] : i64
+ // CHECK: %[[MUL:.+]] = llvm.mul %[[DIV]], %[[ARG1]] : i64
+ // CHECK: %[[NEXACT:.+]] = llvm.icmp "ne" %[[ARG0]], %[[MUL]] : i64
+ // CHECK: %[[NNEG:.+]] = llvm.icmp "slt" %[[ARG0]], %[[ZERO]] : i64
+ // CHECK: %[[MNEG:.+]] = llvm.icmp "slt" %[[ARG1]], %[[ZERO]] : i64
+ // CHECK: %[[SAMESIGN:.+]] = llvm.icmp "eq" %[[NNEG]], %[[MNEG]] : i1
+ // CHECK: %[[SHOULDROUND:.+]] = llvm.and %[[NEXACT]], %[[SAMESIGN]] : i1
+ // CHECK: %[[CEIL:.+]] = llvm.add %[[DIV]], %[[ONE]] : i64
+ // CHECK: %[[RES:.+]] = llvm.select %[[SHOULDROUND]], %[[CEIL]], %[[DIV]] : i1, i64
+ %0 = arith.ceildivsi %arg0, %arg1 : i64
return %0: i64
}
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 174eb468cc004..bdf022642b717 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -7,25 +7,17 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
%res = arith.ceildivsi %arg0, %arg1 : i32
return %res : i32
-// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
-// CHECK: [[MINONE:%.+]] = arith.constant -1 : i32
-// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : i32
-// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
-// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
-// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
-// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
-// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
-// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : i32
-// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
+// CHECK: [[DIV:%.+]] = arith.divsi %arg0, %arg1 : i32
+// CHECK: [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : i32
+// CHECK: [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : i32
+// CHECK: [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : i32
+// CHECK: [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : i32
+// CHECK: [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK: [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK: [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
+// CHECK: [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : i32
}
// -----
@@ -37,25 +29,18 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
%res = arith.ceildivsi %arg0, %arg1 : index
return %res : index
-// CHECK: [[ONE:%.+]] = arith.constant 1 : index
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
-// CHECK: [[MINONE:%.+]] = arith.constant -1 : index
-// CHECK: [[CMP1:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[X:%.+]] = arith.select [[CMP1]], [[MINONE]], [[ONE]] : index
-// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
-// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
-// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
-// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
-// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
-// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[MPOS:%.+]] = arith.cmpi sgt, [[ARG1]], [[ZERO]] : index
-// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
-// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
-// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+// CHECK: [[ONE:%.+]] = arith.constant 1 : index
+// CHECK: [[DIV:%.+]] = arith.divsi %arg0, %arg1 : index
+// CHECK: [[MUL:%.+]] = arith.muli [[DIV]], %arg1 : index
+// CHECK: [[NEXACT:%.+]] = arith.cmpi ne, %arg0, [[MUL]] : index
+// CHECK: [[NNEG:%.+]] = arith.cmpi slt, %arg0, [[ZERO]] : index
+// CHECK: [[MNEG:%.+]] = arith.cmpi slt, %arg1, [[ZERO]] : index
+// CHECK: [[SAMESIGN:%.+]] = arith.cmpi eq, [[NNEG]], [[MNEG]] : i1
+// CHECK: [[SHOULDROUND:%.+]] = arith.andi [[NEXACT]], [[SAMESIGN]] : i1
+// CHECK: [[CEIL:%.+]] = arith.addi [[DIV]], [[ONE]] : index
+// CHECK: [[RES:%.+]] = arith.select [[SHOULDROUND]], [[CEIL]], [[DIV]] : index
+
}
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/133774
More information about the Mlir-commits
mailing list