[flang-commits] [flang] ac87d6b - [mlir][arith] Fold `arith.cmpi eq, %val, %one : i1` -> `%val` and `arith.cmpi ne, %val, %zero : i1` -> `%val` (#124436)

via flang-commits flang-commits at lists.llvm.org
Mon Jan 27 03:28:12 PST 2025


Author: Ivan Butygin
Date: 2025-01-27T14:28:09+03:00
New Revision: ac87d6b03642eca3901a7776d73be368299402e9

URL: https://github.com/llvm/llvm-project/commit/ac87d6b03642eca3901a7776d73be368299402e9
DIFF: https://github.com/llvm/llvm-project/commit/ac87d6b03642eca3901a7776d73be368299402e9.diff

LOG: [mlir][arith] Fold `arith.cmpi eq, %val, %one : i1` -> `%val` and `arith.cmpi ne, %val, %zero : i1` -> `%val` (#124436)

Alive2 verification of the transform: https://alive2.llvm.org/ce/z/dNZMdC

Added: 
    

Modified: 
    flang/test/Lower/Intrinsics/ieee_next.f90
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Dialect/Arith/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/flang/test/Lower/Intrinsics/ieee_next.f90 b/flang/test/Lower/Intrinsics/ieee_next.f90
index fa9692b83bc874..eb9cc028368a5a 100644
--- a/flang/test/Lower/Intrinsics/ieee_next.f90
+++ b/flang/test/Lower/Intrinsics/ieee_next.f90
@@ -131,9 +131,8 @@ program p
   ! CHECK:     %[[V_106:[0-9]+]] = arith.bitcast %[[V_104]] : f32 to i32
   ! CHECK:     %[[V_107:[0-9]+]] = arith.shrui %[[V_106]], %c31{{.*}} : i32
   ! CHECK:     %[[V_108:[0-9]+]] = fir.convert %[[V_107]] : (i32) -> i1
-  ! CHECK:     %[[V_109:[0-9]+]] = arith.cmpi ne, %[[V_108]], %false{{[_0-9]*}} : i1
   ! CHECK:     %[[V_110:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_104]]) <{bit = 516 : i32}> : (f32) -> i1
-  ! CHECK:     %[[V_111:[0-9]+]] = arith.andi %[[V_110]], %[[V_109]] : i1
+  ! CHECK:     %[[V_111:[0-9]+]] = arith.andi %[[V_110]], %[[V_108]] : i1
   ! CHECK:     %[[V_112:[0-9]+]] = arith.ori %[[V_105]], %[[V_111]] : i1
   ! CHECK:     %[[V_113:[0-9]+]] = fir.if %[[V_112]] -> (f32) {
   ! CHECK:       %[[V_202:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_104]]) <{bit = 1 : i32}> : (f32) -> i1
@@ -149,7 +148,7 @@ program p
   ! CHECK:       } else {
   ! CHECK-DAG:     %[[V_204:[0-9]+]] = arith.subi %[[V_106]], %c1{{.*}} : i32
   ! CHECK-DAG:     %[[V_205:[0-9]+]] = arith.addi %[[V_106]], %c1{{.*}} : i32
-  ! CHECK:         %[[V_206:[0-9]+]] = arith.select %[[V_109]], %[[V_205]], %[[V_204]] : i32
+  ! CHECK:         %[[V_206:[0-9]+]] = arith.select %[[V_108]], %[[V_205]], %[[V_204]] : i32
   ! CHECK:         %[[V_207:[0-9]+]] = arith.bitcast %[[V_206]] : i32 to f32
   ! CHECK:         fir.result %[[V_207]] : f32
   ! CHECK:       }
@@ -253,9 +252,8 @@ program p
   ! CHECK:     %[[V_182:[0-9]+]] = arith.bitcast %[[V_180]] : f128 to i128
   ! CHECK:     %[[V_183:[0-9]+]] = arith.shrui %[[V_182]], %c127{{.*}} : i128
   ! CHECK:     %[[V_184:[0-9]+]] = fir.convert %[[V_183]] : (i128) -> i1
-  ! CHECK:     %[[V_185:[0-9]+]] = arith.cmpi ne, %[[V_184]], %false{{[_0-9]*}} : i1
   ! CHECK:     %[[V_186:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_180]]) <{bit = 516 : i32}> : (f128) -> i1
-  ! CHECK:     %[[V_187:[0-9]+]] = arith.andi %[[V_186]], %[[V_185]] : i1
+  ! CHECK:     %[[V_187:[0-9]+]] = arith.andi %[[V_186]], %[[V_184]] : i1
   ! CHECK:     %[[V_188:[0-9]+]] = arith.ori %[[V_181]], %[[V_187]] : i1
   ! CHECK:     %[[V_189:[0-9]+]] = fir.if %[[V_188]] -> (f128) {
   ! CHECK:       %[[V_202:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_180]]) <{bit = 1 : i32}> : (f128) -> i1
@@ -271,7 +269,7 @@ program p
   ! CHECK:       } else {
   ! CHECK-DAG:     %[[V_204:[0-9]+]] = arith.subi %[[V_182]], %c1{{.*}} : i128
   ! CHECK-DAG:     %[[V_205:[0-9]+]] = arith.addi %[[V_182]], %c1{{.*}} : i128
-  ! CHECK:         %[[V_206:[0-9]+]] = arith.select %[[V_185]], %[[V_205]], %[[V_204]] : i128
+  ! CHECK:         %[[V_206:[0-9]+]] = arith.select %[[V_184]], %[[V_205]], %[[V_204]] : i128
   ! CHECK:         %[[V_207:[0-9]+]] = arith.bitcast %[[V_206]] : i128 to f128
   ! CHECK:         fir.result %[[V_207]] : f128
   ! CHECK:       }

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7ca104691e6df6..75d59ba8c1a108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1865,6 +1865,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
           getPredicate() == arith::CmpIPredicate::ne)
         return extOp.getOperand();
     }
+
+    // arith.cmpi ne, %val, %zero : i1 -> %val
+    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+        getPredicate() == arith::CmpIPredicate::ne)
+      return getLhs();
+  }
+
+  if (matchPattern(adaptor.getRhs(), m_One())) {
+    // arith.cmpi eq, %val, %one : i1 -> %val
+    if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+        getPredicate() == arith::CmpIPredicate::eq)
+      return getLhs();
   }
 
   // Move constant to the right side.

diff  --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 522711b08f289d..3a16ee3d4f8fde 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -160,6 +160,78 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
   return %res1, %res2 : i32, i32
 }
 
+// CHECK-LABEL: @cmpiI1eq
+//  CHECK-SAME: (%[[ARG:.*]]: i1)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1eq(%arg0: i1) -> i1 {
+  %one = arith.constant 1 : i1
+  %res = arith.cmpi eq, %arg0, %one : i1
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVec
+//  CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1eqVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %one = arith.constant dense<1> : vector<4xi1>
+  %res = arith.cmpi eq, %arg0, %one : vector<4xi1>
+  return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1ne
+//  CHECK-SAME: (%[[ARG:.*]]: i1)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1ne(%arg0: i1) -> i1 {
+  %zero = arith.constant 0 : i1
+  %res = arith.cmpi ne, %arg0, %zero : i1
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVec
+//  CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1neVec(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %zero = arith.constant dense<0> : vector<4xi1>
+  %res = arith.cmpi ne, %arg0, %zero : vector<4xi1>
+  return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1eqLhs
+//  CHECK-SAME: (%[[ARG:.*]]: i1)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1eqLhs(%arg0: i1) -> i1 {
+  %one = arith.constant 1 : i1
+  %res = arith.cmpi eq, %one, %arg0  : i1
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVecLhs
+//  CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1eqVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %one = arith.constant dense<1> : vector<4xi1>
+  %res = arith.cmpi eq, %one, %arg0 : vector<4xi1>
+  return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1neLhs
+//  CHECK-SAME: (%[[ARG:.*]]: i1)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1neLhs(%arg0: i1) -> i1 {
+  %zero = arith.constant 0 : i1
+  %res = arith.cmpi ne, %zero, %arg0 : i1
+  return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVecLhs
+//  CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+//       CHECK: return %[[ARG]]
+func.func @cmpiI1neVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> {
+  %zero = arith.constant dense<0> : vector<4xi1>
+  %res = arith.cmpi ne, %zero, %arg0 : vector<4xi1>
+  return %res : vector<4xi1>
+}
+
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true


        


More information about the flang-commits mailing list