[Mlir-commits] [mlir] [mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts (PR #82133)

Felix Schneider llvmlistbot at llvm.org
Sun Feb 18 00:51:30 PST 2024


https://github.com/ubfx updated https://github.com/llvm/llvm-project/pull/82133

>From d5dabc96a57974c3662aefe56ec29cddd40b7596 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sat, 17 Feb 2024 23:18:04 +0100
Subject: [PATCH 1/3] [mlir][arith] Update documentation of shift Ops on
 allowed shift amounts

This patch aligns the documentation of the shift Ops in `arith` with
the respective LLVM instructions. Specifically, it is now stated that
shifting by an amount equal to the operand bitwidth returns poison.
---
 mlir/include/mlir/Dialect/Arith/IR/ArithOps.td | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4babbe80e285f7..c9df50d0395d1f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
     The `shli` operation shifts the integer value of the first operand to the left 
     by the integer value of the second operand. The second operand is interpreted as 
     unsigned. The low order bits are filled with zeros. If the value of the second 
-    operand is greater than the bitwidth of the first operand, then the 
+    operand is greater than or equal to the bitwidth of the first operand, then the
     operation returns poison.
 
     This op supports `nuw`/`nsw` overflow flags which stands stand for
@@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
     The `shrui` operation shifts an integer value of the first operand to the right 
     by the value of the second operand. The first operand is interpreted as unsigned,
     and the second operand is interpreted as unsigned. The high order bits are always 
-    filled with zeros. If the value of the second operand is greater than the bitwidth
-    of the first operand, then the operation returns poison.
+    filled with zeros. If the value of the second operand is greater than or equal to
+    the bitwidth of the first operand, then the operation returns poison.
 
     Example:
 
@@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
     and the second operand is interpreter as unsigned. The high order bits in the 
     output are filled with copies of the most-significant bit of the shifted value 
     (which means that the sign of the value is preserved). If the value of the second 
-    operand is greater than bitwidth of the first operand, then the operation returns 
-    poison.
+    operand is greater than or equal to the bitwidth of the first operand, then the
+    operation returns poison.
 
     Example:
 

>From 9a917e14086bc0b87116446d53486989a4fc2aa1 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 18 Feb 2024 08:18:38 +0100
Subject: [PATCH 2/3] Update folders and add tests

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp    |  8 +++---
 mlir/test/Dialect/Arith/canonicalize.mlir | 33 +++++++++++++++++++++++
 2 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 275c2debe9a6fc..64729261fa8272 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2383,7 +2383,7 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
-        bounded = b.ule(b.getBitWidth());
+        bounded = b.ult(b.getBitWidth());
         return a.shl(b);
       });
   return bounded ? result : Attribute();
@@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
   // shrui(x, 0) -> x
   if (matchPattern(adaptor.getRhs(), m_Zero()))
     return getLhs();
-  // Don't fold if shifting more than the bit width.
+  // Don't fold if shifting by the bit width or more.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
-        bounded = b.ule(b.getBitWidth());
+        bounded = b.ult(b.getBitWidth());
         return a.lshr(b);
       });
   return bounded ? result : Attribute();
@@ -2419,7 +2419,7 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
-        bounded = b.ule(b.getBitWidth());
+        bounded = b.ult(b.getBitWidth());
         return a.ashr(b);
       });
   return bounded ? result : Attribute();
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f128b13e9f732d..cb98a10048a309 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 {
   return %r : i64
 }
 
+// CHECK-LABEL: @nofoldShl3(
+// CHECK: %[[res:.+]] = arith.shli
+// CHECK: return %[[res]]
+func.func @nofoldShl3() -> i64 {
+  %c1 = arith.constant 1 : i64
+  %c64 = arith.constant 64 : i64
+  // Note that this should return Poison in the future.
+  %r = arith.shli %c1, %c64 : i64
+  return %r : i64
+}
+
 // CHECK-LABEL: @foldShru(
 // CHECK: %[[res:.+]] = arith.constant 2 : i64
 // CHECK: return %[[res]]
@@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 {
   return %r : i64
 }
 
+// CHECK-LABEL: @nofoldShru3(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func.func @nofoldShru3() -> i64 {
+  %c1 = arith.constant 8 : i64
+  %c64 = arith.constant 64 : i64
+  // Note that this should return Poison in the future.
+  %r = arith.shrui %c1, %c64 : i64
+  return %r : i64
+}
+
 // CHECK-LABEL: @foldShrs(
 // CHECK: %[[res:.+]] = arith.constant 2 : i64
 // CHECK: return %[[res]]
@@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 {
   return %r : i64
 }
 
+// CHECK-LABEL: @nofoldShrs3(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func.func @nofoldShrs3() -> i64 {
+  %c1 = arith.constant 8 : i64
+  %c64 = arith.constant 64 : i64
+  // Note that this should return Poison in the future.
+  %r = arith.shrsi %c1, %c64 : i64
+  return %r : i64
+}
+
 // -----
 
 // CHECK-LABEL: @test_negf(

>From a752d1ca9fe7cc4364bc9c42e2d45fe24e2d8d90 Mon Sep 17 00:00:00 2001
From: Felix Schneider <fx.schn at gmail.com>
Date: Sun, 18 Feb 2024 09:51:11 +0100
Subject: [PATCH 3/3] Update comments

---
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 64729261fa8272..0f71c19c23b654 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2379,7 +2379,7 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
   // shli(x, 0) -> x
   if (matchPattern(adaptor.getRhs(), m_Zero()))
     return getLhs();
-  // Don't fold if shifting more than the bit width.
+  // Don't fold if shifting by the bit width or more.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
@@ -2415,7 +2415,7 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
   // shrsi(x, 0) -> x
   if (matchPattern(adaptor.getRhs(), m_Zero()))
     return getLhs();
-  // Don't fold if shifting more than the bit width.
+  // Don't fold if shifting by the bit width or more.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {



More information about the Mlir-commits mailing list