[Mlir-commits] [mlir] 5505320 - [mlir][Arith] Add constant folder for right shift

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Mar 20 19:34:35 PDT 2022


Author: jacquesguan
Date: 2022-03-21T09:58:18+08:00
New Revision: 55053205e5fae77a92039a43b3c355e535f9d8c6

URL: https://github.com/llvm/llvm-project/commit/55053205e5fae77a92039a43b3c355e535f9d8c6
DIFF: https://github.com/llvm/llvm-project/commit/55053205e5fae77a92039a43b3c355e535f9d8c6.diff

LOG: [mlir][Arith] Add constant folder for right shift

Differential Revision: https://reviews.llvm.org/D121985

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
index 86a0bc19c3d1c..1dfcc8dd5cff0 100644
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -532,6 +532,7 @@ def Arith_ShRUIOp : Arith_IntBinaryOp<"shrui"> {
     %3 = arith.shrui %1, %2 : (i8, i8) -> i8   // %3 is 0b00010100
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -556,6 +557,7 @@ def Arith_ShRSIOp : Arith_IntBinaryOp<"shrsi"> {
     %5 = arith.shrsi %4, %2 : (i8, i8) -> i8   // %5 is 0b00001100
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 812dc3b19dbbf..fd19cdac51f65 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1871,6 +1871,36 @@ OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
   return bounded ? result : Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// ShRUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only if every (per-element) shift amount is below the bit width.
+  bool bounded = true;
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      operands, [&](const APInt &a, const APInt &b) {
+        bounded &= b.ult(b.getBitWidth());
+        return a.lshr(b);
+      });
+  return bounded ? result : Attribute();
+}
+
+//===----------------------------------------------------------------------===//
+// ShRSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
+  // Fold only if every (per-element) shift amount is below the bit width.
+  bool bounded = true;
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      operands, [&](const APInt &a, const APInt &b) {
+        bounded &= b.ult(b.getBitWidth());
+        return a.ashr(b);
+      });
+  return bounded ? result : Attribute();
+}
+
 //===----------------------------------------------------------------------===//
 // Atomic Enum
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index d4b551455d070..ee0cf1a46b08c 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1022,3 +1022,83 @@ func @nofoldShl2() -> i64 {
   %r = arith.shli %c1, %cm32 : i64
   return %r : i64
 }
+
+// CHECK-LABEL: @foldShru(
+// CHECK: %[[res:.+]] = arith.constant 2 : i64
+// CHECK: return %[[res]]
+func @foldShru() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c2 = arith.constant 2 : i64
+  %r = arith.shrui %c8, %c2 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShru2(
+// CHECK: %[[res:.+]] = arith.constant 9223372036854775807 : i64
+// CHECK: return %[[res]]
+func @foldShru2() -> i64 {
+  %cm2 = arith.constant -2 : i64
+  %c1 = arith.constant 1 : i64
+  %r = arith.shrui %cm2, %c1 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShru(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func @nofoldShru() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c132 = arith.constant 132 : i64
+  %r = arith.shrui %c8, %c132 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShru2(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func @nofoldShru2() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %cm32 = arith.constant -32 : i64
+  %r = arith.shrui %c8, %cm32 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShrs(
+// CHECK: %[[res:.+]] = arith.constant 2 : i64
+// CHECK: return %[[res]]
+func @foldShrs() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c2 = arith.constant 2 : i64
+  %r = arith.shrsi %c8, %c2 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShrs2(
+// CHECK: %[[res:.+]] = arith.constant -1 : i64
+// CHECK: return %[[res]]
+func @foldShrs2() -> i64 {
+  %cm2 = arith.constant -2 : i64
+  %c1 = arith.constant 1 : i64
+  %r = arith.shrsi %cm2, %c1 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShrs(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func @nofoldShrs() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c132 = arith.constant 132 : i64
+  %r = arith.shrsi %c8, %c132 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShrs2(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func @nofoldShrs2() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %cm32 = arith.constant -32 : i64
+  %r = arith.shrsi %c8, %cm32 : i64
+  return %r : i64
+}


        


More information about the Mlir-commits mailing list