[Mlir-commits] [mlir] [mlir][arith] fix ceildivsi lowering (PR #106696)
Joachim Giron
llvmlistbot at llvm.org
Thu Sep 5 09:23:10 PDT 2024
https://github.com/jgiron42 updated https://github.com/llvm/llvm-project/pull/106696
>From 9d99eafb0f0dceecbed92af4c4df0c80117dd850 Mon Sep 17 00:00:00 2001
From: joachim <jgiron at student.42.fr>
Date: Fri, 30 Aug 2024 11:33:56 +0200
Subject: [PATCH] [mlir][arith] fix ceildivsi lowering
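This commit fixes the overflow in the case of
ceildivsi( <signed_type_min>, <positive_integer> ).
The previous lowering of the n*m <= 0 branch computed -(-n / m), and
negating the minimum signed value overflows. When n*m <= 0, the exact
quotient is non-positive, so truncation toward zero already rounds up.
That branch therefore simplifies to the plain signed division n / m.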
---
mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp | 8 +++-----
mlir/test/Dialect/Arith/expand-ops.mlir | 12 ++++--------
2 files changed, 7 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 54be644a710113..a37a18135386e1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -60,7 +60,7 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
/// Expands CeilDivSIOp (n, m) into
/// 1) x = (m > 0) ? -1 : 1
-/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+/// 2) (n*m>0) ? ((n+x) / m) + 1 : n / m
struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -80,10 +80,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
- // Compute negative res: - ((-a)/b).
- Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
- Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
- Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
+ // Compute negative res: a/b.
+ Value negRes = rewriter.create<arith::DivSIOp>(loc, a, b);
// Result is (a*b>0) ? pos result : neg result.
// Note, we want to avoid using a*b because of possible overflow.
// The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 174eb468cc0041..3ba89ab740ef18 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -15,9 +15,7 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
+// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
@@ -25,7 +23,7 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
}
// -----
@@ -45,9 +43,7 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
// CHECK: [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
// CHECK: [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
// CHECK: [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK: [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK: [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK: [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
+// CHECK: [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
// CHECK: [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
// CHECK: [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
// CHECK: [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
@@ -55,7 +51,7 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
// CHECK: [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
// CHECK: [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+// CHECK: [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
}
// -----
More information about the Mlir-commits
mailing list