[Mlir-commits] [mlir] [mlir][arith] fix ceildivsi lowering (PR #106696)

Joachim Giron llvmlistbot at llvm.org
Thu Sep 5 09:23:10 PDT 2024


https://github.com/jgiron42 updated https://github.com/llvm/llvm-project/pull/106696

>From 9d99eafb0f0dceecbed92af4c4df0c80117dd850 Mon Sep 17 00:00:00 2001
From: joachim <jgiron at student.42.fr>
Date: Fri, 30 Aug 2024 11:33:56 +0200
Subject: [PATCH] [mlir][arith] fix ceildivsi lowering

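This commit fixes the overflow in the case of
ceildivsi( <signed_type_min>, <positive_integer> ).

When the operands do not have strictly matching signs, the old expansion
computed -((-n) / m). Negating the signed minimum wraps back to itself, which
produced a wrong (positive) result. In that branch the exact quotient is
non-positive, so truncating signed division n / m already rounds toward the
ceiling and needs no negation.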
---
 mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp |  8 +++-----
 mlir/test/Dialect/Arith/expand-ops.mlir         | 12 ++++--------
 2 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
index 54be644a710113..a37a18135386e1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -60,7 +60,7 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
 
 /// Expands CeilDivSIOp (n, m) into
 ///   1) x = (m > 0) ? -1 : 1
-///   2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
+///   2) (n*m>0) ? ((n+x) / m) + 1 : n / m
 struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
@@ -80,10 +80,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
     Value xPlusA = rewriter.create<arith::AddIOp>(loc, x, a);
     Value xPlusADivB = rewriter.create<arith::DivSIOp>(loc, xPlusA, b);
     Value posRes = rewriter.create<arith::AddIOp>(loc, plusOne, xPlusADivB);
-    // Compute negative res: - ((-a)/b).
-    Value minusA = rewriter.create<arith::SubIOp>(loc, zero, a);
-    Value minusADivB = rewriter.create<arith::DivSIOp>(loc, minusA, b);
-    Value negRes = rewriter.create<arith::SubIOp>(loc, zero, minusADivB);
+    // Compute negative res: a/b.
+    Value negRes = rewriter.create<arith::DivSIOp>(loc, a, b);
     // Result is (a*b>0) ? pos result : neg result.
     // Note, we want to avoid using a*b because of possible overflow.
     // The case that matters are a>0, a==0, a<0, b>0 and b<0. We do
diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir
index 174eb468cc0041..3ba89ab740ef18 100644
--- a/mlir/test/Dialect/Arith/expand-ops.mlir
+++ b/mlir/test/Dialect/Arith/expand-ops.mlir
@@ -15,9 +15,7 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
 // CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : i32
 // CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : i32
 // CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : i32
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : i32
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : i32
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : i32
+// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : i32
 // CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : i32
 // CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : i32
 // CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : i32
@@ -25,7 +23,7 @@ func.func @ceildivi(%arg0: i32, %arg1: i32) -> (i32) {
 // CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
 // CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
 // CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : i32
+// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : i32
 }
 
 // -----
@@ -45,9 +43,7 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
 // CHECK:           [[TRUE1:%.+]] = arith.addi [[X]], [[ARG0]] : index
 // CHECK:           [[TRUE2:%.+]] = arith.divsi [[TRUE1]], [[ARG1]] : index
 // CHECK:           [[TRUE3:%.+]] = arith.addi [[ONE]], [[TRUE2]] : index
-// CHECK:           [[FALSE1:%.+]] = arith.subi [[ZERO]], [[ARG0]] : index
-// CHECK:           [[FALSE2:%.+]] = arith.divsi [[FALSE1]], [[ARG1]] : index
-// CHECK:           [[FALSE3:%.+]] = arith.subi [[ZERO]], [[FALSE2]] : index
+// CHECK:           [[FALSE:%.+]] = arith.divsi [[ARG0]], [[ARG1]] : index
 // CHECK:           [[NNEG:%.+]] = arith.cmpi slt, [[ARG0]], [[ZERO]] : index
 // CHECK:           [[NPOS:%.+]] = arith.cmpi sgt, [[ARG0]], [[ZERO]] : index
 // CHECK:           [[MNEG:%.+]] = arith.cmpi slt, [[ARG1]], [[ZERO]] : index
@@ -55,7 +51,7 @@ func.func @ceildivi_index(%arg0: index, %arg1: index) -> (index) {
 // CHECK:           [[TERM1:%.+]] = arith.andi [[NNEG]], [[MNEG]] : i1
 // CHECK:           [[TERM2:%.+]] = arith.andi [[NPOS]], [[MPOS]] : i1
 // CHECK:           [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
-// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE3]] : index
+// CHECK:           [[RES:%.+]] = arith.select [[CMP2]], [[TRUE3]], [[FALSE]] : index
 }
 
 // -----



More information about the Mlir-commits mailing list