[llvm] [InstCombine] Fold `(x < y) ? -1 : zext(x > y)` and `(x > y) ? 1 : sext(x < y)` to `ucmp/scmp(x, y)` (PR #105272)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 20 12:19:55 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Volodymyr Vasylkun (Poseydon42)
<details>
<summary>Changes</summary>
This patch expands the already existing functionality to include these two additional folds, which are nearly identical to the ones already implemented.
Proofs: https://alive2.llvm.org/ce/z/Xy7s4j
---
Full diff: https://github.com/llvm/llvm-project/pull/105272.diff
4 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+13-4)
- (modified) llvm/test/Transforms/InstCombine/scmp.ll (+28)
- (modified) llvm/test/Transforms/InstCombine/select-select.ll (+11-31)
- (modified) llvm/test/Transforms/InstCombine/ucmp.ll (+30-2)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 1f6d5759883fd0..18ffc209f259e0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -3560,7 +3560,9 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) {
// This function tries to fold the following operations:
// (x < y) ? -1 : zext(x != y)
+// (x < y) ? -1 : zext(x > y)
// (x > y) ? 1 : sext(x != y)
+// (x > y) ? 1 : sext(x < y)
// Into ucmp/scmp(x, y), where signedness is determined by the signedness
// of the comparison in the original sequence.
Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
@@ -3589,16 +3591,23 @@ Instruction *InstCombinerImpl::foldSelectToCmp(SelectInst &SI) {
ICmpInst::isSigned(Pred) ? Intrinsic::scmp : Intrinsic::ucmp;
bool Replace = false;
+ ICmpInst::Predicate ExtendedCmpPredicate;
// (x < y) ? -1 : zext(x != y)
+ // (x < y) ? -1 : zext(x > y)
if (ICmpInst::isLT(Pred) && match(TV, m_AllOnes()) &&
- match(FV, m_ZExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
- m_Specific(RHS)))))
+ match(FV, m_ZExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
+ m_Specific(RHS)))) &&
+ (ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
+ ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
Replace = true;
// (x > y) ? 1 : sext(x != y)
+ // (x > y) ? 1 : sext(x < y)
if (ICmpInst::isGT(Pred) && match(TV, m_One()) &&
- match(FV, m_SExt(m_c_SpecificICmp(ICmpInst::ICMP_NE, m_Specific(LHS),
- m_Specific(RHS)))))
+ match(FV, m_SExt(m_c_ICmp(ExtendedCmpPredicate, m_Specific(LHS),
+ m_Specific(RHS)))) &&
+ (ExtendedCmpPredicate == ICmpInst::ICMP_NE ||
+ ICmpInst::getSwappedPredicate(ExtendedCmpPredicate) == Pred))
Replace = true;
if (Replace)
diff --git a/llvm/test/Transforms/InstCombine/scmp.ll b/llvm/test/Transforms/InstCombine/scmp.ll
index 7f374c5f9a1d64..e2312140c8c13d 100644
--- a/llvm/test/Transforms/InstCombine/scmp.ll
+++ b/llvm/test/Transforms/InstCombine/scmp.ll
@@ -223,6 +223,20 @@ define i8 @scmp_from_select_lt(i32 %x, i32 %y) {
ret i8 %r
}
+; Fold (x s< y) ? -1 : zext(x s> y) into scmp(x, y)
+define i8 @scmp_from_select_lt_and_gt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_lt_and_gt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %gt_bool = icmp sgt i32 %x, %y
+ %gt = zext i1 %gt_bool to i8
+ %lt = icmp slt i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %gt
+ ret i8 %r
+}
+
; Vector version
define <4 x i8> @scmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: define <4 x i8> @scmp_from_select_vec_lt(
@@ -264,3 +278,17 @@ define i8 @scmp_from_select_ge(i32 %x, i32 %y) {
%r = select i1 %ge, i8 %ne, i8 -1
ret i8 %r
}
+
+; Fold (x s> y) ? 1 : sext(x s< y)
+define i8 @scmp_from_select_gt_and_lt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @scmp_from_select_gt_and_lt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %lt_bool = icmp slt i32 %x, %y
+ %lt = sext i1 %lt_bool to i8
+ %gt = icmp sgt i32 %x, %y
+ %r = select i1 %gt, i8 1, i8 %lt
+ ret i8 %r
+}
diff --git a/llvm/test/Transforms/InstCombine/select-select.ll b/llvm/test/Transforms/InstCombine/select-select.ll
index 5460ba1bc55838..1feae5ab504dcf 100644
--- a/llvm/test/Transforms/InstCombine/select-select.ll
+++ b/llvm/test/Transforms/InstCombine/select-select.ll
@@ -18,9 +18,9 @@ define float @foo1(float %a) {
define float @foo2(float %a) {
; CHECK-LABEL: @foo2(
-; CHECK-NEXT: [[B:%.*]] = fcmp ule float [[C:%.*]], 0.000000e+00
-; CHECK-NEXT: [[D:%.*]] = fcmp olt float [[C]], 1.000000e+00
-; CHECK-NEXT: [[E:%.*]] = select i1 [[D]], float [[C]], float 1.000000e+00
+; CHECK-NEXT: [[B:%.*]] = fcmp ule float [[A:%.*]], 0.000000e+00
+; CHECK-NEXT: [[TMP1:%.*]] = fcmp olt float [[A]], 1.000000e+00
+; CHECK-NEXT: [[E:%.*]] = select i1 [[TMP1]], float [[A]], float 1.000000e+00
; CHECK-NEXT: [[F:%.*]] = select i1 [[B]], float 0.000000e+00, float [[E]]
; CHECK-NEXT: ret float [[F]]
;
@@ -330,10 +330,7 @@ define i8 @strong_order_cmp_eq_ugt(i32 %a, i32 %b) {
define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_slt_sgt(
-; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
-; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
-; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.lt = icmp slt i32 %a, %b
@@ -345,10 +342,7 @@ define i8 @strong_order_cmp_slt_sgt(i32 %a, i32 %b) {
define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_ult_ugt(
-; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
-; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A]], [[B]]
-; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.lt = icmp ult i32 %a, %b
@@ -360,10 +354,7 @@ define i8 @strong_order_cmp_ult_ugt(i32 %a, i32 %b) {
define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_sgt_slt(
-; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
-; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A]], [[B]]
-; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.gt = icmp sgt i32 %a, %b
@@ -375,10 +366,7 @@ define i8 @strong_order_cmp_sgt_slt(i32 %a, i32 %b) {
define i8 @strong_order_cmp_ugt_ult(i32 %a, i32 %b) {
; CHECK-LABEL: @strong_order_cmp_ugt_ult(
-; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
-; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
-; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.gt = icmp ugt i32 %a, %b
@@ -460,8 +448,7 @@ define i8 @strong_order_cmp_ugt_ult_zext_not_oneuse(i32 %a, i32 %b) {
; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[ZEXT:%.*]] = zext i1 [[CMP_GT]] to i8
; CHECK-NEXT: call void @use8(i8 [[ZEXT]])
-; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult i32 [[A]], [[B]]
-; CHECK-NEXT: [[SEL_LT:%.*]] = select i1 [[CMP_LT]], i8 -1, i8 [[ZEXT]]
+; CHECK-NEXT: [[SEL_LT:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[A]], i32 [[B]])
; CHECK-NEXT: ret i8 [[SEL_LT]]
;
%cmp.gt = icmp ugt i32 %a, %b
@@ -477,8 +464,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {
; CHECK-NEXT: [[CMP_LT:%.*]] = icmp slt i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[CMP_LT]] to i8
; CHECK-NEXT: call void @use8(i8 [[SEXT]])
-; CHECK-NEXT: [[CMP_GT:%.*]] = icmp sgt i32 [[A]], [[B]]
-; CHECK-NEXT: [[SEL_GT:%.*]] = select i1 [[CMP_GT]], i8 1, i8 [[SEXT]]
+; CHECK-NEXT: [[SEL_GT:%.*]] = call i8 @llvm.scmp.i8.i32(i32 [[A]], i32 [[B]])
; CHECK-NEXT: ret i8 [[SEL_GT]]
;
%cmp.lt = icmp slt i32 %a, %b
@@ -491,10 +477,7 @@ define i8 @strong_order_cmp_slt_sgt_sext_not_oneuse(i32 %a, i32 %b) {
define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector(
-; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
-; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
-; CHECK-NEXT: [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 -1, i8 -1>, <2 x i8> [[ZEXT]]
+; CHECK-NEXT: [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_LT]]
;
%cmp.gt = icmp ugt <2 x i32> %a, %b
@@ -506,10 +489,7 @@ define <2 x i8> @strong_order_cmp_ugt_ult_vector(<2 x i32> %a, <2 x i32> %b) {
define <2 x i8> @strong_order_cmp_ugt_ult_vector_poison(<2 x i32> %a, <2 x i32> %b) {
; CHECK-LABEL: @strong_order_cmp_ugt_ult_vector_poison(
-; CHECK-NEXT: [[CMP_GT:%.*]] = icmp ugt <2 x i32> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[ZEXT:%.*]] = zext <2 x i1> [[CMP_GT]] to <2 x i8>
-; CHECK-NEXT: [[CMP_LT:%.*]] = icmp ult <2 x i32> [[A]], [[B]]
-; CHECK-NEXT: [[SEL_LT:%.*]] = select <2 x i1> [[CMP_LT]], <2 x i8> <i8 poison, i8 -1>, <2 x i8> [[ZEXT]]
+; CHECK-NEXT: [[SEL_LT:%.*]] = call <2 x i8> @llvm.ucmp.v2i8.v2i32(<2 x i32> [[A:%.*]], <2 x i32> [[B:%.*]])
; CHECK-NEXT: ret <2 x i8> [[SEL_LT]]
;
%cmp.gt = icmp ugt <2 x i32> %a, %b
diff --git a/llvm/test/Transforms/InstCombine/ucmp.ll b/llvm/test/Transforms/InstCombine/ucmp.ll
index ad8a57825253b0..13755f13bb0a11 100644
--- a/llvm/test/Transforms/InstCombine/ucmp.ll
+++ b/llvm/test/Transforms/InstCombine/ucmp.ll
@@ -222,6 +222,20 @@ define i8 @ucmp_from_select_lt(i32 %x, i32 %y) {
ret i8 %r
}
+; Fold (x u< y) ? -1 : zext(x u> y) into ucmp(x, y)
+define i8 @ucmp_from_select_lt_and_gt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_lt_and_gt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %gt_bool = icmp ugt i32 %x, %y
+ %gt = zext i1 %gt_bool to i8
+ %lt = icmp ult i32 %x, %y
+ %r = select i1 %lt, i8 -1, i8 %gt
+ ret i8 %r
+}
+
; Vector version
define <4 x i8> @ucmp_from_select_vec_lt(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: define <4 x i8> @ucmp_from_select_vec_lt(
@@ -349,13 +363,13 @@ define i8 @ucmp_from_select_le_neg1(i32 %x, i32 %y) {
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
; CHECK-NEXT: [[NE_BOOL:%.*]] = icmp ult i32 [[X]], [[Y]]
; CHECK-NEXT: [[NE:%.*]] = sext i1 [[NE_BOOL]] to i8
-; CHECK-NEXT: [[LE_NOT:%.*]] = icmp ugt i32 [[X]], [[Y]]
+; CHECK-NEXT: [[LE_NOT:%.*]] = icmp ult i32 [[X]], [[Y]]
; CHECK-NEXT: [[R:%.*]] = select i1 [[LE_NOT]], i8 1, i8 [[NE]]
; CHECK-NEXT: ret i8 [[R]]
;
%ne_bool = icmp ult i32 %x, %y
%ne = sext i1 %ne_bool to i8
- %le = icmp ule i32 %x, %y
+ %le = icmp uge i32 %x, %y
%r = select i1 %le, i8 %ne, i8 1
ret i8 %r
}
@@ -513,3 +527,17 @@ define i8 @ucmp_from_select_ge_neg4(i32 %x, i32 %y) {
%r = select i1 %ge, i8 %ne, i8 3
ret i8 %r
}
+
+; Fold (x > y) ? 1 : sext(x < y)
+define i8 @ucmp_from_select_gt_and_lt(i32 %x, i32 %y) {
+; CHECK-LABEL: define i8 @ucmp_from_select_gt_and_lt(
+; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
+; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.ucmp.i8.i32(i32 [[X]], i32 [[Y]])
+; CHECK-NEXT: ret i8 [[R]]
+;
+ %lt_bool = icmp ult i32 %x, %y
+ %lt = sext i1 %lt_bool to i8
+ %gt = icmp ugt i32 %x, %y
+ %r = select i1 %gt, i8 1, i8 %lt
+ ret i8 %r
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/105272
More information about the llvm-commits
mailing list