[flang-commits] [flang] ac87d6b - [mlir][arith] Fold `arith.cmpi eq, %val, %one : i1` -> `%val` and `arith.cmpi ne, %val, %zero : i1` -> `%val` (#124436)
via flang-commits
flang-commits at lists.llvm.org
Mon Jan 27 03:28:12 PST 2025
Author: Ivan Butygin
Date: 2025-01-27T14:28:09+03:00
New Revision: ac87d6b03642eca3901a7776d73be368299402e9
URL: https://github.com/llvm/llvm-project/commit/ac87d6b03642eca3901a7776d73be368299402e9
DIFF: https://github.com/llvm/llvm-project/commit/ac87d6b03642eca3901a7776d73be368299402e9.diff
LOG: [mlir][arith] Fold `arith.cmpi eq, %val, %one : i1` -> `%val` and `arith.cmpi ne, %val, %zero : i1` -> `%val` (#124436)
Alive2 verification of the fold: https://alive2.llvm.org/ce/z/dNZMdC
Added:
Modified:
flang/test/Lower/Intrinsics/ieee_next.f90
mlir/lib/Dialect/Arith/IR/ArithOps.cpp
mlir/test/Dialect/Arith/canonicalize.mlir
Removed:
################################################################################
diff --git a/flang/test/Lower/Intrinsics/ieee_next.f90 b/flang/test/Lower/Intrinsics/ieee_next.f90
index fa9692b83bc874..eb9cc028368a5a 100644
--- a/flang/test/Lower/Intrinsics/ieee_next.f90
+++ b/flang/test/Lower/Intrinsics/ieee_next.f90
@@ -131,9 +131,8 @@ program p
! CHECK: %[[V_106:[0-9]+]] = arith.bitcast %[[V_104]] : f32 to i32
! CHECK: %[[V_107:[0-9]+]] = arith.shrui %[[V_106]], %c31{{.*}} : i32
! CHECK: %[[V_108:[0-9]+]] = fir.convert %[[V_107]] : (i32) -> i1
- ! CHECK: %[[V_109:[0-9]+]] = arith.cmpi ne, %[[V_108]], %false{{[_0-9]*}} : i1
! CHECK: %[[V_110:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_104]]) <{bit = 516 : i32}> : (f32) -> i1
- ! CHECK: %[[V_111:[0-9]+]] = arith.andi %[[V_110]], %[[V_109]] : i1
+ ! CHECK: %[[V_111:[0-9]+]] = arith.andi %[[V_110]], %[[V_108]] : i1
! CHECK: %[[V_112:[0-9]+]] = arith.ori %[[V_105]], %[[V_111]] : i1
! CHECK: %[[V_113:[0-9]+]] = fir.if %[[V_112]] -> (f32) {
! CHECK: %[[V_202:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_104]]) <{bit = 1 : i32}> : (f32) -> i1
@@ -149,7 +148,7 @@ program p
! CHECK: } else {
! CHECK-DAG: %[[V_204:[0-9]+]] = arith.subi %[[V_106]], %c1{{.*}} : i32
! CHECK-DAG: %[[V_205:[0-9]+]] = arith.addi %[[V_106]], %c1{{.*}} : i32
- ! CHECK: %[[V_206:[0-9]+]] = arith.select %[[V_109]], %[[V_205]], %[[V_204]] : i32
+ ! CHECK: %[[V_206:[0-9]+]] = arith.select %[[V_108]], %[[V_205]], %[[V_204]] : i32
! CHECK: %[[V_207:[0-9]+]] = arith.bitcast %[[V_206]] : i32 to f32
! CHECK: fir.result %[[V_207]] : f32
! CHECK: }
@@ -253,9 +252,8 @@ program p
! CHECK: %[[V_182:[0-9]+]] = arith.bitcast %[[V_180]] : f128 to i128
! CHECK: %[[V_183:[0-9]+]] = arith.shrui %[[V_182]], %c127{{.*}} : i128
! CHECK: %[[V_184:[0-9]+]] = fir.convert %[[V_183]] : (i128) -> i1
- ! CHECK: %[[V_185:[0-9]+]] = arith.cmpi ne, %[[V_184]], %false{{[_0-9]*}} : i1
! CHECK: %[[V_186:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_180]]) <{bit = 516 : i32}> : (f128) -> i1
- ! CHECK: %[[V_187:[0-9]+]] = arith.andi %[[V_186]], %[[V_185]] : i1
+ ! CHECK: %[[V_187:[0-9]+]] = arith.andi %[[V_186]], %[[V_184]] : i1
! CHECK: %[[V_188:[0-9]+]] = arith.ori %[[V_181]], %[[V_187]] : i1
! CHECK: %[[V_189:[0-9]+]] = fir.if %[[V_188]] -> (f128) {
! CHECK: %[[V_202:[0-9]+]] = "llvm.intr.is.fpclass"(%[[V_180]]) <{bit = 1 : i32}> : (f128) -> i1
@@ -271,7 +269,7 @@ program p
! CHECK: } else {
! CHECK-DAG: %[[V_204:[0-9]+]] = arith.subi %[[V_182]], %c1{{.*}} : i128
! CHECK-DAG: %[[V_205:[0-9]+]] = arith.addi %[[V_182]], %c1{{.*}} : i128
- ! CHECK: %[[V_206:[0-9]+]] = arith.select %[[V_185]], %[[V_205]], %[[V_204]] : i128
+ ! CHECK: %[[V_206:[0-9]+]] = arith.select %[[V_184]], %[[V_205]], %[[V_204]] : i128
! CHECK: %[[V_207:[0-9]+]] = arith.bitcast %[[V_206]] : i128 to f128
! CHECK: fir.result %[[V_207]] : f128
! CHECK: }
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7ca104691e6df6..75d59ba8c1a108 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1865,6 +1865,18 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) {
getPredicate() == arith::CmpIPredicate::ne)
return extOp.getOperand();
}
+
+ // arith.cmpi ne, %val, %zero : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::ne)
+ return getLhs();
+ }
+
+ if (matchPattern(adaptor.getRhs(), m_One())) {
+ // arith.cmpi eq, %val, %one : i1 -> %val
+ if (getElementTypeOrSelf(getLhs().getType()).isInteger(1) &&
+ getPredicate() == arith::CmpIPredicate::eq)
+ return getLhs();
}
// Move constant to the right side.
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 522711b08f289d..3a16ee3d4f8fde 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -160,6 +160,78 @@ func.func @selNotCond(%arg0: i1, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 :
return %res1, %res2 : i32, i32
}
+// CHECK-LABEL: @cmpiI1eq
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eq(%arg0: i1) -> i1 { // scalar i1: `cmpi eq, %x, 1` should fold away to %x
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %arg0, %one : i1 // for i1, x == 1 is just x
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVec(%arg0: vector<4xi1>) -> vector<4xi1> { // vector form of the eq-with-one fold
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %arg0, %one : vector<4xi1> // fold tests the element type, so i1 vectors qualify too
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1ne
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1ne(%arg0: i1) -> i1 { // scalar i1: `cmpi ne, %x, 0` should fold away to %x
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %arg0, %zero : i1 // for i1, x != 0 is just x
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVec
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVec(%arg0: vector<4xi1>) -> vector<4xi1> { // vector form of the ne-with-zero fold
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %arg0, %zero : vector<4xi1> // fold tests the element type, so i1 vectors qualify too
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1eqLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqLhs(%arg0: i1) -> i1 { // eq-with-one fold with the constant written on the LHS
+ %one = arith.constant 1 : i1
+ %res = arith.cmpi eq, %one, %arg0 : i1 // new fold only matches an RHS constant; relies on cmpi's fold moving constants to the RHS
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1eqVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1eqVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> { // vector eq-with-one fold, constant on the LHS
+ %one = arith.constant dense<1> : vector<4xi1>
+ %res = arith.cmpi eq, %one, %arg0 : vector<4xi1> // relies on the constant being moved to the RHS before the i1 fold applies
+ return %res : vector<4xi1>
+}
+
+// CHECK-LABEL: @cmpiI1neLhs
+// CHECK-SAME: (%[[ARG:.*]]: i1)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neLhs(%arg0: i1) -> i1 { // ne-with-zero fold with the constant written on the LHS
+ %zero = arith.constant 0 : i1
+ %res = arith.cmpi ne, %zero, %arg0 : i1 // new fold only matches an RHS constant; relies on cmpi's fold moving constants to the RHS
+ return %res : i1
+}
+
+// CHECK-LABEL: @cmpiI1neVecLhs
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi1>)
+// CHECK: return %[[ARG]]
+func.func @cmpiI1neVecLhs(%arg0: vector<4xi1>) -> vector<4xi1> { // vector ne-with-zero fold, constant on the LHS
+ %zero = arith.constant dense<0> : vector<4xi1>
+ %res = arith.cmpi ne, %zero, %arg0 : vector<4xi1> // relies on the constant being moved to the RHS before the i1 fold applies
+ return %res : vector<4xi1>
+}
+
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = arith.constant true
More information about the flang-commits mailing list