[Mlir-commits] [mlir] 6c295a9 - [mlir][Arith] Fold integer shift op with zero.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Dec 30 01:19:31 PST 2022


Author: jacquesguan
Date: 2022-12-30T17:19:23+08:00
New Revision: 6c295a932d26681f07037d7289df405e36350dd4

URL: https://github.com/llvm/llvm-project/commit/6c295a932d26681f07037d7289df405e36350dd4
DIFF: https://github.com/llvm/llvm-project/commit/6c295a932d26681f07037d7289df405e36350dd4.diff

LOG: [mlir][Arith] Fold integer shift op with zero.

This revision folds arith.shrui, arith.shrsi, and arith.shli to their
left-hand side (lhs) operand when the right-hand side (rhs) is zero.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D140749

Added: 
    

Modified: 
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 844905c20210a..f812b3c61b366 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2221,6 +2221,9 @@ LogicalResult arith::SelectOp::verify() {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
+  // shli(x, 0) -> x
+  if (matchPattern(getRhs(), m_Zero()))
+    return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
@@ -2236,6 +2239,9 @@ OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
+  // shrui(x, 0) -> x
+  if (matchPattern(getRhs(), m_Zero()))
+    return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(
@@ -2251,6 +2257,9 @@ OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
+  // shrsi(x, 0) -> x
+  if (matchPattern(getRhs(), m_Zero()))
+    return getLhs();
   // Don't fold if shifting more than the bit width.
   bool bounded = false;
   auto result = constFoldBinaryOp<IntegerAttr>(

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index d181ae9d9d111..02cbaa28ab3f7 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2107,3 +2107,30 @@ func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
   %hi = arith.trunci %sh: i64 to i32
   return %hi : i32
 }
+
+// CHECK-LABEL: @foldShli0
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+//       CHECK:   return %[[ARG]] : i64
+func.func @foldShli0(%x : i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %r = arith.shli %x, %c0 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShrui0
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+//       CHECK:   return %[[ARG]] : i64
+func.func @foldShrui0(%x : i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %r = arith.shrui %x, %c0 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShrsi0
+// CHECK-SAME: (%[[ARG:.*]]: i64)
+//       CHECK:   return %[[ARG]] : i64
+func.func @foldShrsi0(%x : i64) -> i64 {
+  %c0 = arith.constant 0 : i64
+  %r = arith.shrsi %x, %c0 : i64
+  return %r : i64
+}


        


More information about the Mlir-commits mailing list