[Mlir-commits] [mlir] 1a8c613 - [mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts (#82133)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Feb 18 01:17:06 PST 2024


Author: Felix Schneider
Date: 2024-02-18T10:17:03+01:00
New Revision: 1a8c6130f60fe517fb722ab4309997ce7b638234

URL: https://github.com/llvm/llvm-project/commit/1a8c6130f60fe517fb722ab4309997ce7b638234
DIFF: https://github.com/llvm/llvm-project/commit/1a8c6130f60fe517fb722ab4309997ce7b638234.diff

LOG: [mlir][arith] Align shift Ops with LLVM instructions on allowed shift amounts (#82133)

This patch aligns the shift ops in `arith` with the respective LLVM instructions.
Specifically, shifting by an amount equal to the bitwidth of the operand
is now defined to return poison (previously only larger amounts did).

Relevant discussion:
https://discourse.llvm.org/t/some-question-on-the-semantics-of-the-arith-dialect/74861/10
Relevant issue: https://github.com/llvm/llvm-project/issues/80960

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4babbe80e285f7..c9df50d0395d1f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -788,7 +788,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
     The `shli` operation shifts the integer value of the first operand to the left 
     by the integer value of the second operand. The second operand is interpreted as 
     unsigned. The low order bits are filled with zeros. If the value of the second 
-    operand is greater than the bitwidth of the first operand, then the 
+    operand is greater than or equal to the bitwidth of the first operand, then the
     operation returns poison.
 
     This op supports `nuw`/`nsw` overflow flags which stands stand for
@@ -818,8 +818,8 @@ def Arith_ShRUIOp : Arith_TotalIntBinaryOp<"shrui"> {
     The `shrui` operation shifts an integer value of the first operand to the right 
     by the value of the second operand. The first operand is interpreted as unsigned,
     and the second operand is interpreted as unsigned. The high order bits are always 
-    filled with zeros. If the value of the second operand is greater than the bitwidth
-    of the first operand, then the operation returns poison.
+    filled with zeros. If the value of the second operand is greater than or equal to
+    the bitwidth of the first operand, then the operation returns poison.
 
     Example:
 
@@ -844,8 +844,8 @@ def Arith_ShRSIOp : Arith_TotalIntBinaryOp<"shrsi"> {
     and the second operand is interpreter as unsigned. The high order bits in the 
     output are filled with copies of the most-significant bit of the shifted value 
     (which means that the sign of the value is preserved). If the value of the second 
-    operand is greater than bitwidth of the first operand, then the operation returns 
-    poison.
+    operand is greater than or equal to the bitwidth of the first operand, then the
+    operation returns poison.
 
     Example:
 

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 275c2debe9a6fc..0f71c19c23b654 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2379,11 +2379,11 @@ OpFoldResult arith::ShLIOp::fold(FoldAdaptor adaptor) {
   // shli(x, 0) -> x
   if (matchPattern(adaptor.getRhs(), m_Zero()))
     return getLhs();
-  // Don't fold if shifting more than the bit width.
+  // Don't fold if shifting by an amount greater than or equal to the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
-        bounded = b.ule(b.getBitWidth());
+        bounded = b.ult(b.getBitWidth());
         return a.shl(b);
       });
   return bounded ? result : Attribute();
@@ -2397,11 +2397,11 @@ OpFoldResult arith::ShRUIOp::fold(FoldAdaptor adaptor) {
   // shrui(x, 0) -> x
   if (matchPattern(adaptor.getRhs(), m_Zero()))
     return getLhs();
-  // Don't fold if shifting more than the bit width.
+  // Don't fold if shifting by an amount greater than or equal to the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
-        bounded = b.ule(b.getBitWidth());
+        bounded = b.ult(b.getBitWidth());
         return a.lshr(b);
       });
   return bounded ? result : Attribute();
@@ -2415,11 +2415,11 @@ OpFoldResult arith::ShRSIOp::fold(FoldAdaptor adaptor) {
   // shrsi(x, 0) -> x
   if (matchPattern(adaptor.getRhs(), m_Zero()))
     return getLhs();
-  // Don't fold if shifting more than the bit width.
+  // Don't fold if shifting by an amount greater than or equal to the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
       adaptor.getOperands(), [&](const APInt &a, const APInt &b) {
-        bounded = b.ule(b.getBitWidth());
+        bounded = b.ult(b.getBitWidth());
         return a.ashr(b);
       });
   return bounded ? result : Attribute();

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index f128b13e9f732d..cb98a10048a309 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2179,6 +2179,17 @@ func.func @nofoldShl2() -> i64 {
   return %r : i64
 }
 
+// CHECK-LABEL: @nofoldShl3(
+// CHECK: %[[res:.+]] = arith.shli
+// CHECK: return %[[res]]
+func.func @nofoldShl3() -> i64 {
+  %c1 = arith.constant 1 : i64
+  %c64 = arith.constant 64 : i64
+  // Note that this should return Poison in the future.
+  %r = arith.shli %c1, %c64 : i64
+  return %r : i64
+}
+
 // CHECK-LABEL: @foldShru(
 // CHECK: %[[res:.+]] = arith.constant 2 : i64
 // CHECK: return %[[res]]
@@ -2219,6 +2230,17 @@ func.func @nofoldShru2() -> i64 {
   return %r : i64
 }
 
+// CHECK-LABEL: @nofoldShru3(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func.func @nofoldShru3() -> i64 {
+  %c1 = arith.constant 8 : i64
+  %c64 = arith.constant 64 : i64
+  // Note that this should return Poison in the future.
+  %r = arith.shrui %c1, %c64 : i64
+  return %r : i64
+}
+
 // CHECK-LABEL: @foldShrs(
 // CHECK: %[[res:.+]] = arith.constant 2 : i64
 // CHECK: return %[[res]]
@@ -2259,6 +2281,17 @@ func.func @nofoldShrs2() -> i64 {
   return %r : i64
 }
 
+// CHECK-LABEL: @nofoldShrs3(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func.func @nofoldShrs3() -> i64 {
+  %c1 = arith.constant 8 : i64
+  %c64 = arith.constant 64 : i64
+  // Note that this should return Poison in the future.
+  %r = arith.shrsi %c1, %c64 : i64
+  return %r : i64
+}
+
 // -----
 
 // CHECK-LABEL: @test_negf(


        


More information about the Mlir-commits mailing list